Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +3 -0
- .gradio/certificate.pem +31 -0
- README.md +224 -12
- dataset_utils.py +203 -0
- gguf_utils.py +18 -0
- inference_utils.py +25 -0
- main.py +5 -0
- model_utils.py +60 -0
- openai_sample_dataset.json +178 -0
- project_plan.md +104 -0
- requirements.txt +15 -0
- training_utils.py +75 -0
- ui.py +231 -0
- unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
- unsloth_compiled_cache/UnslothBCOTrainer.py +1824 -0
- unsloth_compiled_cache/UnslothCPOTrainer.py +1557 -0
- unsloth_compiled_cache/UnslothDDPOTrainer.py +872 -0
- unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothGKDTrainer.py +863 -0
- unsloth_compiled_cache/UnslothGRPOTrainer.py +1438 -0
- unsloth_compiled_cache/UnslothKTOTrainer.py +1840 -0
- unsloth_compiled_cache/UnslothNashMDTrainer.py +955 -0
- unsloth_compiled_cache/UnslothORPOTrainer.py +1543 -0
- unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1269 -0
- unsloth_compiled_cache/UnslothPPOTrainer.py +1259 -0
- unsloth_compiled_cache/UnslothPRMTrainer.py +800 -0
- unsloth_compiled_cache/UnslothRLOOTrainer.py +1133 -0
- unsloth_compiled_cache/UnslothRewardTrainer.py +819 -0
- unsloth_compiled_cache/UnslothSFTTrainer.py +1031 -0
- unsloth_compiled_cache/UnslothXPOTrainer.py +1010 -0
- unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc +3 -0
- unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
- upload_utils.py +77 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
outputs/
|
| 3 |
+
results/
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
README.md
CHANGED
|
@@ -1,12 +1,224 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Finetune
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Finetune-Test
|
| 3 |
+
app_file: ui.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 5.23.2
|
| 6 |
+
---
|
| 7 |
+
# LLM Finetuner
|
| 8 |
+
|
| 9 |
+
This project provides a user-friendly interface for fine-tuning Large Language Models (LLMs) using the Unsloth library. It includes features for dataset preparation, synthetic dataset creation, model training, testing, and GGUF conversion.
|
| 10 |
+
|
| 11 |
+
## Features
|
| 12 |
+
|
| 13 |
+
- Load and fine-tune various pre-trained models
|
| 14 |
+
- Prepare existing datasets or create synthetic datasets
|
| 15 |
+
- Fine-tune models with customizable hyperparameters
|
| 16 |
+
- Test fine-tuned models
|
| 17 |
+
- Convert models to GGUF format for deployment
|
| 18 |
+
|
| 19 |
+
## Prerequisites
|
| 20 |
+
|
| 21 |
+
- Python 3.8 or higher
|
| 22 |
+
- CUDA-capable GPU (for efficient training)
|
| 23 |
+
|
| 24 |
+
## Installation
|
| 25 |
+
|
| 26 |
+
1. Clone the repository:
|
| 27 |
+
```
|
| 28 |
+
git clone https://github.com/yourusername/llm-finetuner.git
|
| 29 |
+
cd llm-finetuner
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
2. Create a virtual environment (optional but recommended):
|
| 33 |
+
```
|
| 34 |
+
python -m venv venv
|
| 35 |
+
source venv/bin/activate # On Windows, use `venv\Scripts\activate`
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
3. Install the required packages:
|
| 39 |
+
```
|
| 40 |
+
pip install -r requirements.txt
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Usage
|
| 44 |
+
|
| 45 |
+
1. Run the application:
|
| 46 |
+
```
|
| 47 |
+
python main.py
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
2. Open the provided URL in your web browser to access the Gradio interface.
|
| 51 |
+
|
| 52 |
+
3. Follow these steps in the interface:
|
| 53 |
+
a. Settings: Enter your Hugging Face token and select a model.
|
| 54 |
+
b. Dataset: Prepare an existing dataset or create a synthetic one.
|
| 55 |
+
c. Training: Set hyperparameters and start the fine-tuning process.
|
| 56 |
+
d. Test: Test your fine-tuned model with custom inputs.
|
| 57 |
+
e. GGUF Conversion: Convert your model to GGUF format if needed.
|
| 58 |
+
|
| 59 |
+
## Notes
|
| 60 |
+
|
| 61 |
+
- Ensure you have the necessary API keys for OpenAI or Anthropic if you plan to use them for synthetic dataset creation.
|
| 62 |
+
- If using Ollama for local generation, make sure it's installed and running on your machine.
|
| 63 |
+
- Fine-tuning can be computationally intensive. Ensure you have adequate GPU resources available.
|
| 64 |
+
|
| 65 |
+
## Contributing
|
| 66 |
+
|
| 67 |
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
| 68 |
+
|
| 69 |
+
## License
|
| 70 |
+
|
| 71 |
+
This project is licensed under the MIT License.
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Comprehensive Python Setup Guide
|
| 75 |
+
|
| 76 |
+
This guide will walk you through setting up Python, creating a virtual environment, and running your LLM Finetuner project on a new system.
|
| 77 |
+
|
| 78 |
+
## 1. Install Python
|
| 79 |
+
|
| 80 |
+
### Windows:
|
| 81 |
+
1. Go to https://www.python.org/downloads/windows/
|
| 82 |
+
2. Download the latest Python 3.x installer (64-bit version recommended)
|
| 83 |
+
3. Run the installer
|
| 84 |
+
4. Check "Add Python to PATH" during installation
|
| 85 |
+
5. Click "Install Now"
|
| 86 |
+
|
| 87 |
+
### macOS:
|
| 88 |
+
1. Install Homebrew if you haven't already:
|
| 89 |
+
```
|
| 90 |
+
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
| 91 |
+
```
|
| 92 |
+
2. Install Python using Homebrew:
|
| 93 |
+
```
|
| 94 |
+
brew install python
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### Linux (Ubuntu/Debian):
|
| 98 |
+
1. Update package list:
|
| 99 |
+
```
|
| 100 |
+
sudo apt update
|
| 101 |
+
```
|
| 102 |
+
2. Install Python:
|
| 103 |
+
```
|
| 104 |
+
sudo apt install python3 python3-pip python3-venv
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
## 2. Verify Python Installation
|
| 108 |
+
|
| 109 |
+
Open a terminal (Command Prompt on Windows) and run:
|
| 110 |
+
```
|
| 111 |
+
python --version
|
| 112 |
+
```
|
| 113 |
+
You should see the Python version number. If not, try `python3 --version`.
|
| 114 |
+
|
| 115 |
+
## 3. Install Git
|
| 116 |
+
|
| 117 |
+
### Windows:
|
| 118 |
+
1. Go to https://git-scm.com/download/win
|
| 119 |
+
2. Download and run the installer
|
| 120 |
+
3. Use the default settings during installation
|
| 121 |
+
|
| 122 |
+
### macOS:
|
| 123 |
+
If you installed Homebrew earlier:
|
| 124 |
+
```
|
| 125 |
+
brew install git
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Linux (Ubuntu/Debian):
|
| 129 |
+
```
|
| 130 |
+
sudo apt install git
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
## 4. Clone the Repository
|
| 134 |
+
|
| 135 |
+
1. Open a terminal
|
| 136 |
+
2. Navigate to where you want to store the project
|
| 137 |
+
3. Clone the repository:
|
| 138 |
+
```
|
| 139 |
+
git clone https://github.com/yourusername/llm-finetuner.git
|
| 140 |
+
cd llm-finetuner
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
## 5. Create and Activate a Virtual Environment
|
| 144 |
+
|
| 145 |
+
### Windows:
|
| 146 |
+
```
|
| 147 |
+
python -m venv venv
|
| 148 |
+
venv\Scripts\activate
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
### macOS and Linux:
|
| 152 |
+
```
|
| 153 |
+
python3 -m venv venv
|
| 154 |
+
source venv/bin/activate
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
Your prompt should change to indicate that the virtual environment is active.
|
| 158 |
+
|
| 159 |
+
## 6. Install Required Packages
|
| 160 |
+
|
| 161 |
+
With the virtual environment activated:
|
| 162 |
+
```
|
| 163 |
+
pip install -r requirements.txt
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
This may take a while as it installs all necessary dependencies.
|
| 167 |
+
|
| 168 |
+
## 7. Set Up CUDA (for GPU support)
|
| 169 |
+
|
| 170 |
+
If you have an NVIDIA GPU and want to use it for training:
|
| 171 |
+
|
| 172 |
+
1. Go to https://developer.nvidia.com/cuda-downloads
|
| 173 |
+
2. Download and install the CUDA Toolkit appropriate for your system
|
| 174 |
+
3. Install the cuDNN library:
|
| 175 |
+
- Go to https://developer.nvidia.com/cudnn
|
| 176 |
+
- Download cuDNN (you may need to create an NVIDIA account)
|
| 177 |
+
- Follow the installation instructions for your system
|
| 178 |
+
|
| 179 |
+
## 8. Run the Application
|
| 180 |
+
|
| 181 |
+
With the virtual environment still activated:
|
| 182 |
+
```
|
| 183 |
+
python main.py
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
This will start the Gradio interface. Open the provided URL in your web browser.
|
| 187 |
+
|
| 188 |
+
## 9. Using the LLM Finetuner
|
| 189 |
+
|
| 190 |
+
1. In the "Settings" tab:
|
| 191 |
+
- Enter your Hugging Face token
|
| 192 |
+
- Select a model
|
| 193 |
+
|
| 194 |
+
2. In the "Dataset" tab:
|
| 195 |
+
- Prepare an existing dataset or create a synthetic one
|
| 196 |
+
|
| 197 |
+
3. In the "Training" tab:
|
| 198 |
+
- Set hyperparameters and start training
|
| 199 |
+
|
| 200 |
+
4. In the "Test" tab:
|
| 201 |
+
- Test your fine-tuned model
|
| 202 |
+
|
| 203 |
+
5. In the "GGUF Conversion" tab:
|
| 204 |
+
- Convert your model to GGUF format if needed
|
| 205 |
+
|
| 206 |
+
## Troubleshooting
|
| 207 |
+
|
| 208 |
+
- If `python` doesn't work, try `python3`
|
| 209 |
+
- Ensure your GPU drivers are up to date for CUDA support
|
| 210 |
+
- If you encounter "command not found" errors, ensure the relevant programs are in your system's PATH
|
| 211 |
+
|
| 212 |
+
## Closing Notes
|
| 213 |
+
|
| 214 |
+
- Always activate the virtual environment before running the project
|
| 215 |
+
- To deactivate the virtual environment, simply type `deactivate` in the terminal
|
| 216 |
+
- Keep your Python packages updated with `pip install --upgrade -r requirements.txt`
|
| 217 |
+
|
| 218 |
+
Remember to keep your API keys and tokens secure. Happy fine-tuning!
|
| 219 |
+
|
| 220 |
+
## Alternative Installation
|
| 221 |
+
|
| 222 |
+
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
| 223 |
+
pip install triton
|
| 224 |
+
pip install unsloth gradio transformers datasets tqdm
|
dataset_utils.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset, Dataset
|
| 2 |
+
import json
|
| 3 |
+
import csv
|
| 4 |
+
import openai
|
| 5 |
+
import anthropic
|
| 6 |
+
import requests
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
def prepare_dataset(dataset_source, dataset_path, tokenizer, hf_token=None):
    """
    Prepare a dataset for fine-tuning, either from Hugging Face or a local file.

    The loaded data must contain a 'conversations' column, or a 'text' column
    from which single-turn conversations (empty assistant reply) are built.
    Conversations are normalised with ``standardize_sharegpt`` and then
    rendered through the tokenizer's chat template into the 'text' column.

    Args:
        dataset_source (str): 'huggingface' or 'local'
        dataset_path (str): Path or identifier of the dataset
        tokenizer: The tokenizer associated with the model
        hf_token (str, optional): Hugging Face token for accessing datasets

    Returns:
        Dataset: Prepared dataset ready for fine-tuning

    Raises:
        FileNotFoundError: If a local ``dataset_path`` does not exist.
        ValueError: If the source or file format is unsupported, the data has
            neither a 'conversations' nor a 'text' column, or ``tokenizer``
            is None.
    """
    if dataset_source == 'huggingface':
        try:
            # `token` is the current keyword; `use_auth_token` is deprecated.
            dataset = load_dataset(dataset_path, split="train", token=hf_token)
        except ValueError:
            # If the token keyword is not supported, try without it
            dataset = load_dataset(dataset_path, split="train")
    elif dataset_source == 'local':
        if not os.path.exists(dataset_path):
            raise FileNotFoundError(f"File not found: {dataset_path}")

        if dataset_path.endswith('.json'):
            # Explicit UTF-8 so parsing does not depend on the platform locale.
            with open(dataset_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                dataset = Dataset.from_list(data)
            elif isinstance(data, dict):
                dataset = Dataset.from_dict(data)
            else:
                raise ValueError("JSON file must contain either a list or a dictionary.")
        elif dataset_path.endswith('.csv'):
            # newline='' is what the csv module expects for correct handling
            # of quoted fields that contain line breaks.
            with open(dataset_path, 'r', encoding='utf-8', newline='') as f:
                data = list(csv.DictReader(f))
            dataset = Dataset.from_list(data)
        else:
            raise ValueError("Unsupported file format. Please use JSON or CSV.")
    else:
        raise ValueError("Invalid dataset source. Use 'huggingface' or 'local'.")

    # Check if 'conversations' column exists, if not, try to create it
    if 'conversations' not in dataset.column_names:
        if 'text' not in dataset.column_names:
            raise ValueError("Dataset does not contain 'conversations' or 'text' column. Please check your dataset structure.")
        dataset = dataset.map(lambda example: {'conversations': [{'human': example['text'], 'assistant': ''}]})

    # Fail before any mapping work is done rather than inside the map callback.
    if tokenizer is None:
        raise ValueError("Tokenizer is not properly initialized. Please load the model and tokenizer before preparing the dataset.")

    # A 'conversations' column is guaranteed to exist at this point.
    dataset = standardize_sharegpt(dataset)

    def formatting_prompts_func(examples):
        # Batched callback: render every conversation with the chat template.
        convos = examples["conversations"]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
        return {"text": texts}

    return dataset.map(formatting_prompts_func, batched=True)
|
| 87 |
+
|
| 88 |
+
def standardize_sharegpt(dataset):
    """
    Normalise the 'conversations' column to a list of role/content turns.

    Three turn layouts are recognised:
      * ``{'human': ..., 'assistant': ...}`` - expands to a user turn followed
        by an assistant turn (either key may be missing or None);
      * ``{'from': ..., 'value': ...}`` - ShareGPT layout; 'human' maps to
        'user' and 'gpt' to 'assistant', other speaker tags are kept as-is;
      * ``{'role': ..., 'content': ...}`` - already standard, passed through.
    Turns that match none of these layouts are dropped.

    This is a simplified version. You might need to adjust it based on your
    specific needs.

    Args:
        dataset: Object exposing ``map``; each row must provide an iterable
            of dict turns under the 'conversations' key.

    Returns:
        The result of ``dataset.map`` with 'conversations' rewritten.
    """
    # ShareGPT speaker tags mapped onto chat roles.
    speaker_to_role = {
        'human': 'user',
        'user': 'user',
        'gpt': 'assistant',
        'assistant': 'assistant',
        'system': 'system',
    }

    def process_conversation(conversation):
        standardized = []
        for turn in conversation:
            human_text = turn.get('human')
            assistant_text = turn.get('assistant')
            # A key holding None is treated like a missing key so that no
            # turn with None content is produced.
            if human_text is not None or assistant_text is not None:
                if human_text is not None:
                    standardized.append({'role': 'user', 'content': human_text})
                if assistant_text is not None:
                    standardized.append({'role': 'assistant', 'content': assistant_text})
            elif turn.get('from') is not None and turn.get('value') is not None:
                speaker = turn['from']
                standardized.append({
                    'role': speaker_to_role.get(speaker, speaker),
                    'content': turn['value'],
                })
            elif turn.get('role') is not None and turn.get('content') is not None:
                standardized.append({'role': turn['role'], 'content': turn['content']})
        return standardized

    return dataset.map(lambda x: {'conversations': process_conversation(x['conversations'])})
|
| 100 |
+
|
| 101 |
+
def create_synthetic_dataset(examples, expected_structure, num_samples, ai_provider, api_key, model_name=None):
    """
    Create a synthetic dataset based on example conversations and expected structure.

    One request is sent per sample. Each response must be a JSON document; it
    is stored as the 'conversations' value of one row. Responses that fail to
    parse, time out, or raise any other error are logged and skipped, so the
    result may contain fewer than ``num_samples`` rows.

    Args:
        examples (str): Example conversations to base the synthetic data on.
        expected_structure (str): Description of the expected dataset structure.
        num_samples (int): Number of synthetic samples to generate.
        ai_provider (str): AI provider to use for generation ('OpenAI', 'Anthropic', or 'Ollama').
        api_key (str): API key for the chosen AI provider.
        model_name (str, optional): Model name for Ollama (if applicable).

    Returns:
        Dataset: Synthetic dataset ready for fine-tuning.
    """
    synthetic_data = []

    prompt = f"""
You are an AI assistant creating training dataset for finetuning a model.
You are provided an one-shot or few-shot output example of output that application expects from the AI model. You are also provided the
expected structure that the to-be trained AI model expects during training process.


Examples:
{examples}

Expected structure:
{expected_structure}

Please help Generate a new dataset in the provided same style and expected structure. Do not produce any extra output except the dataset in the training needed structure:
"""

    # Each provider branch defines how to fetch one raw text response and
    # which exception class signals a timeout; the loop below is shared.
    request_sample = None
    timeout_error = ()

    if ai_provider == "OpenAI":
        client = openai.OpenAI(api_key=api_key)
        timeout_error = openai.APITimeoutError

        def request_sample():
            response = client.chat.completions.create(
                model="gpt-4-0125-preview",
                messages=[{"role": "user", "content": prompt}],
                timeout=30,  # 30 seconds timeout
            )
            return response.choices[0].message.content

    elif ai_provider == "Anthropic":
        client = anthropic.Anthropic(api_key=api_key)
        timeout_error = anthropic.APITimeoutError

        def request_sample():
            # Claude 3 models are served through the Messages API, not the
            # legacy Text Completions endpoint.
            response = client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=1000,
                messages=[{"role": "user", "content": prompt}],
                timeout=30,  # 30 seconds timeout
            )
            return response.content[0].text

    elif ai_provider == "Ollama":
        timeout_error = requests.Timeout

        def request_sample():
            response = requests.post(
                'http://localhost:11434/api/generate',
                json={
                    "model": model_name,
                    "prompt": prompt,
                    "stream": False,
                },
                timeout=30,  # 30 seconds timeout
            )
            response.raise_for_status()
            return response.json()["response"]

    else:
        # As before, an unrecognised provider yields an empty dataset; the
        # log entry makes the cause visible.
        logger.error("Unsupported AI provider: %s", ai_provider)

    if request_sample is not None:
        # int() tolerates whole-number floats, e.g. from a UI number widget.
        for _ in tqdm(range(int(num_samples)), desc="Generating samples"):
            raw_text = None
            try:
                raw_text = request_sample()
                synthetic_data.append({"conversations": json.loads(raw_text)})
            except json.JSONDecodeError:
                # raw_text is already captured, so logging cannot raise again.
                logger.warning("Failed to decode response as JSON: %s", raw_text)
            except timeout_error:
                logger.warning("%s API request timed out", ai_provider)
            except Exception as e:
                logger.error("Unexpected error: %s", e)
            time.sleep(1)  # Rate limiting

    dataset = Dataset.from_list(synthetic_data)
    dataset = standardize_sharegpt(dataset)

    if 'text' not in dataset.column_names:
        def format_conversation(example):
            # Plain "role: content" transcript, one turn per line.
            lines = [f"{turn['role']}: {turn['content']}" for turn in example['conversations']]
            return {"text": "\n".join(lines).strip()}

        dataset = dataset.map(format_conversation)

    return dataset
|
gguf_utils.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def convert_to_gguf(model, tokenizer, output_path, quantization_method="q8_0"):
    """
    Export the fine-tuned model in GGUF format and report the outcome.

    The work is delegated to ``model.save_pretrained_gguf``. Any exception it
    raises is caught and turned into an error message instead of propagating.
    The file name in the success message is composed from ``output_path`` and
    ``quantization_method``; it is not checked against the file system.

    Args:
        model: The fine-tuned model to convert.
        tokenizer: The tokenizer associated with the model.
        output_path (str): The path to save the converted model.
        quantization_method (str): The quantization method to use (e.g., "q8_0", "q4_k_m", "q5_k_m", "f16").

    Returns:
        str: A message indicating the success or failure of the conversion.
    """
    try:
        model.save_pretrained_gguf(
            output_path,
            tokenizer,
            quantization_method=quantization_method,
        )
        gguf_file = f"{output_path}-unsloth-{quantization_method}.gguf"
        status = f"Model successfully converted to GGUF format: {gguf_file}"
    except Exception as err:
        status = f"Error converting to GGUF: {str(err)}"
    return status
|
inference_utils.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def test_model(model, tokenizer, input_text):
    """
    Test the fine-tuned model with a given input.

    The input is wrapped as a single user message, rendered with the
    tokenizer's chat template (generation prompt included), and passed to
    ``model.generate`` for up to 64 new tokens.

    Args:
        model: The fine-tuned model to test.
        tokenizer: The tokenizer associated with the model.
        input_text (str): The input text to generate a response for.

    Returns:
        str: The first decoded sequence returned by ``tokenizer.batch_decode``
            for the generated output.
    """
    messages = [
        {"role": "user", "content": input_text},
    ]

    # Place the inputs on the model's own device so models loaded on CPU work
    # too; fall back to "cuda" (the previous hard-coded target) when the model
    # object exposes no `device` attribute.
    target_device = getattr(model, "device", "cuda")

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(target_device)

    # NOTE(review): temperature/min_p are sampling parameters; confirm that
    # sampling is actually enabled for this model's generation config.
    outputs = model.generate(input_ids=inputs, max_new_tokens=64, use_cache=True,
                             temperature=1.5, min_p=0.1)
    return tokenizer.batch_decode(outputs)[0]
|
main.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ui import create_gradio_interface
|
| 2 |
+
|
| 3 |
+
if __name__ == "__main__":
    # Script entry point: build the Gradio interface, then serve it.
    demo = create_gradio_interface()
    demo.launch(
        share=True,
    )
|
model_utils.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import importlib.util
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 4 |
+
|
| 5 |
+
def load_model(model_path, hf_token):
    """
    Load a pre-trained model together with its tokenizer.

    The tokenizer always comes from ``AutoTokenizer``. The model is loaded
    through unsloth when that package can be found; if unsloth is missing or
    raises while loading, the standard transformers loader is used instead.

    Args:
        model_path (str): Path or identifier of the pre-trained model.
        hf_token (str): Hugging Face API token for accessing gated models.

    Returns:
        tuple: Loaded model and tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)

    # Report which device will be used and remember it for the fallback loader.
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("CUDA is not available. Using CPU.")
    device = "cuda" if has_cuda else "cpu"

    # No unsloth installed: go straight to plain transformers.
    if importlib.util.find_spec("unsloth") is None:
        print("unsloth not found. Using standard transformers.")
        return load_with_transformers(model_path, hf_token, device), tokenizer

    try:
        from unsloth import FastLanguageModel
        print("Using unsloth for model loading.")
        unsloth_options = {
            "model_name": model_path,
            "max_seq_length": 2048,
            "dtype": None,  # Automatically choose between float16 and bfloat16
            "load_in_4bit": has_cuda,  # Only use 4-bit quantization if CUDA is available
            "token": hf_token,
        }
        # The tokenizer returned by unsloth is discarded; the AutoTokenizer
        # instance created above is the one handed back to the caller.
        model, _ = FastLanguageModel.from_pretrained(**unsloth_options)
    except Exception as err:
        print(f"Error loading model with unsloth: {err}")
        print("Falling back to standard transformers.")
        model = load_with_transformers(model_path, hf_token, device)

    # Do not use .to(device) for quantized models
    # The device placement is handled automatically by unsloth or transformers
    return model, tokenizer
|
| 52 |
+
|
| 53 |
+
def load_with_transformers(model_path, hf_token, device):
    """
    Helper function to load model with standard transformers library.

    Args:
        model_path (str): Path or identifier of the pre-trained model.
        hf_token (str): Hugging Face API token passed to the loader.
        device (str): "cuda" selects float16 weights; any other value
            selects float32.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    if device == "cuda":
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    loader_options = {
        "device_map": "auto",  # This will handle device placement automatically
        "token": hf_token,
        "torch_dtype": weights_dtype,
    }
    return AutoModelForCausalLM.from_pretrained(model_path, **loader_options)
|
openai_sample_dataset.json
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"conversations": [
|
| 4 |
+
{
|
| 5 |
+
"human": "I'd like to know the price of silver membership.",
|
| 6 |
+
"assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"PriceInquiry\", \"operation_id\": \"fetchMembershipPrice\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"membership\": \"Silver-Gym-Membership\"}}"
|
| 7 |
+
}
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"conversations": [
|
| 12 |
+
{
|
| 13 |
+
"human": "Help me book an appointment.",
|
| 14 |
+
"assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"OnsiteAppointment\", \"operation_id\": \"bookAppointment\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"url\": \"book-appointment\"}, \"tool_body_parameters\": {\"date\": \"next Monday\", \"time\": \"10:00 AM\"}}"
|
| 15 |
+
}
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"conversations": [
|
| 20 |
+
{
|
| 21 |
+
"human": "I'm having some back pain that's why considering the gym.",
|
| 22 |
+
"assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"no\"}"
|
| 23 |
+
}
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"conversations": [
|
| 28 |
+
{
|
| 29 |
+
"human": "Yes, I'd like to know the yoga classes schedule.",
|
| 30 |
+
"assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"yes\",\"tool_name\": \"GetClassSchedule\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"yoga\"}, \"tool_body_parameters\": {}}"
|
| 31 |
+
}
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"conversations": [
|
| 36 |
+
{
|
| 37 |
+
"human": "Help me book a trial session tomorrow.",
|
| 38 |
+
"assistant": "{\"conversation_stage_id\": 6,\"tool_required\": \"yes\",\"tool_name\": \"ScheduleTrialSession\", \"operation_id\": \"bookTrialSession\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"url\": \"schedule-trial\"}, \"tool_body_parameters\": {\"date\": \"tomorrow\", \"time\": \"10:00 AM\"}}"
|
| 39 |
+
}
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"conversations": [
|
| 44 |
+
{
|
| 45 |
+
"human": "I want to lose weight and build muscle.",
|
| 46 |
+
"assistant": "{\"conversation_stage_id\": 4,\"tool_required\": \"no\"}"
|
| 47 |
+
}
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"conversations": [
|
| 52 |
+
{
|
| 53 |
+
"human": "What are the benefits of the Gold membership?",
|
| 54 |
+
"assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"no\"}"
|
| 55 |
+
}
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"conversations": [
|
| 60 |
+
{
|
| 61 |
+
"human": "Yes, I am available next Monday at 10 AM. Help me book a personal training session.",
|
| 62 |
+
"assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"PersonalTraining\", \"operation_id\": \"bookTrainerSession\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"trainer_id\": \"johndoe\"}, \"tool_body_parameters\": {\"date\": \"next Monday\", \"time\": \"10:00 AM\"}}"
|
| 63 |
+
}
|
| 64 |
+
]
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"conversations": [
|
| 68 |
+
{
|
| 69 |
+
"human": "I am free tomorrow afternoon.",
|
| 70 |
+
"assistant": "{\"conversation_stage_id\": 6,\"tool_required\": \"yes\",\"tool_name\": \"GymAppointmentAvailability\", \"operation_id\": \"checkAvailability\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"date\": \"tomorrow\", \"time\": \"afternoon\"}, \"tool_body_parameters\": {}}"
|
| 71 |
+
}
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"conversations": [
|
| 76 |
+
{
|
| 77 |
+
"human": "Yes, I'd like to join the spin fitness class.",
|
| 78 |
+
"assistant": "{\"conversation_stage_id\": 4,\"tool_required\": \"yes\",\"tool_name\": \"GetClassSchedule\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"spin\"}, \"tool_body_parameters\": {}}"
|
| 79 |
+
}
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"conversations": [
|
| 84 |
+
{
|
| 85 |
+
"human": "Sure, when do zumba classes start?",
|
| 86 |
+
"assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"yes\",\"tool_name\": \"GetClassSchedule\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"zumba\"}, \"tool_body_parameters\": {}}"
|
| 87 |
+
}
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"conversations": [
|
| 92 |
+
{
|
| 93 |
+
"human": "I'd like to know if there are any membership discounts available.",
|
| 94 |
+
"assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"DiscountInquiry\", \"operation_id\": \"checkDiscount\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"membership_type\": \"any\"}}"
|
| 95 |
+
}
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"conversations": [
|
| 100 |
+
{
|
| 101 |
+
"human": "Can you provide information about the swimming pool facilities?",
|
| 102 |
+
"assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"no\"}"
|
| 103 |
+
}
|
| 104 |
+
]
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"conversations": [
|
| 108 |
+
{
|
| 109 |
+
"human": "Please help me cancel my gym appointment for tomorrow.",
|
| 110 |
+
"assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"AppointmentCancellation\", \"operation_id\": \"cancelAppointment\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"appointment_id\": \"12345\"}, \"tool_body_parameters\": {}}"
|
| 111 |
+
}
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"conversations": [
|
| 116 |
+
{
|
| 117 |
+
"human": "I need to reschedule my personal training session.",
|
| 118 |
+
"assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"RescheduleTrainingSession\", \"operation_id\": \"rescheduleSession\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"trainer_id\": \"johndoe\"}, \"tool_body_parameters\": {\"new_date\": \"next Tuesday\", \"new_time\": \"11:00 AM\"}}"
|
| 119 |
+
}
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"conversations": [
|
| 124 |
+
{
|
| 125 |
+
"human": "Do you have any yoga classes available in the evening?",
|
| 126 |
+
"assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"ClassScheduleInquiry\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"yoga\", \"time_of_day\": \"evening\"}, \"tool_body_parameters\": {}}"
|
| 127 |
+
}
|
| 128 |
+
]
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"conversations": [
|
| 132 |
+
{
|
| 133 |
+
"human": "I'm interested in the family membership plan. Can you provide more details?",
|
| 134 |
+
"assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"no\"}"
|
| 135 |
+
}
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"conversations": [
|
| 140 |
+
{
|
| 141 |
+
"human": "I'd like to upgrade my membership from silver to gold.",
|
| 142 |
+
"assistant": "{\"conversation_stage_id\": 4,\"tool_required\": \"yes\",\"tool_name\": \"MembershipUpgrade\", \"operation_id\": \"upgradeMembership\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"current_membership\": \"silver\", \"new_membership\": \"gold\"}, \"tool_body_parameters\": {}}"
|
| 143 |
+
}
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"conversations": [
|
| 148 |
+
{
|
| 149 |
+
"human": "Is there a way to pause my membership while I travel?",
|
| 150 |
+
"assistant": "{\"conversation_stage_id\": 5,\"tool_required\": \"yes\",\"tool_name\": \"MembershipHold\", \"operation_id\": \"pauseMembership\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"hold_duration\": \"2 months\"}}"
|
| 151 |
+
}
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"conversations": [
|
| 156 |
+
{
|
| 157 |
+
"human": "Can you assist me with changing my contact information?",
|
| 158 |
+
"assistant": "{\"conversation_stage_id\": 5,\"tool_required\": \"yes\",\"tool_name\": \"UpdateContactInfo\", \"operation_id\": \"updateInfo\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"new_phone\": \"1234567890\", \"new_email\": \"user@example.com\"}}"
|
| 159 |
+
}
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"conversations": [
|
| 164 |
+
{
|
| 165 |
+
"human": "What are the payment options available for membership renewal?",
|
| 166 |
+
"assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"no\"}"
|
| 167 |
+
}
|
| 168 |
+
]
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"conversations": [
|
| 172 |
+
{
|
| 173 |
+
"human": "I want to know the personal trainer's availability next week.",
|
| 174 |
+
"assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"TrainerAvailability\", \"operation_id\": \"checkTrainerAvailability\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"trainer_id\": \"johndoe\", \"date_range\": \"next week\"}, \"tool_body_parameters\": {}}"
|
| 175 |
+
}
|
| 176 |
+
]
|
| 177 |
+
}
|
| 178 |
+
]
|
project_plan.md
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM Finetuner Project Plan
|
| 2 |
+
|
| 3 |
+
## 1. Project Overview
|
| 4 |
+
|
| 5 |
+
The LLM Finetuner is a user-friendly application designed to simplify the process of fine-tuning Large Language Models (LLMs) using the Unsloth library. The application provides a graphical user interface for dataset preparation, model selection, fine-tuning, testing, and GGUF conversion.
|
| 6 |
+
|
| 7 |
+
## 2. Project Structure
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
llm_finetuner/
|
| 11 |
+
├── main.py
|
| 12 |
+
├── ui.py
|
| 13 |
+
├── model_utils.py
|
| 14 |
+
├── dataset_utils.py
|
| 15 |
+
├── training_utils.py
|
| 16 |
+
├── inference_utils.py
|
| 17 |
+
├── gguf_utils.py
|
| 18 |
+
├── requirements.txt
|
| 19 |
+
└── README.md
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
## 3. Key Components
|
| 23 |
+
|
| 24 |
+
### 3.1 User Interface (ui.py)
|
| 25 |
+
- Gradio-based interface with tabs for different functionalities
|
| 26 |
+
- Handles user inputs and interactions
|
| 27 |
+
- Coordinates between different modules
|
| 28 |
+
|
| 29 |
+
### 3.2 Model Utilities (model_utils.py)
|
| 30 |
+
- Handles model loading and initialization
|
| 31 |
+
- Supports various pre-trained models from Unsloth
|
| 32 |
+
|
| 33 |
+
### 3.3 Dataset Utilities (dataset_utils.py)
|
| 34 |
+
- Manages dataset preparation from Hugging Face and local files
|
| 35 |
+
- Implements synthetic dataset creation using AI providers (OpenAI, Anthropic, Ollama)
|
| 36 |
+
|
| 37 |
+
### 3.4 Training Utilities (training_utils.py)
|
| 38 |
+
- Implements the fine-tuning process using Unsloth and TRL
|
| 39 |
+
|
| 40 |
+
### 3.5 Inference Utilities (inference_utils.py)
|
| 41 |
+
- Handles model testing and inference
|
| 42 |
+
|
| 43 |
+
### 3.6 GGUF Conversion Utilities (gguf_utils.py)
|
| 44 |
+
- Manages the conversion of fine-tuned models to GGUF format
|
| 45 |
+
|
| 46 |
+
## 4. Implementation Plan
|
| 47 |
+
|
| 48 |
+
### 4.1 Phase 1: Core Functionality
|
| 49 |
+
- [x] Implement basic UI structure
|
| 50 |
+
- [x] Develop model loading and initialization
|
| 51 |
+
- [x] Implement dataset preparation for Hugging Face and local files using the model's tokenizer and chat template
|
| 52 |
+
- [x] Develop basic fine-tuning functionality using the prepared dataset
|
| 53 |
+
- [x] Implement model testing
|
| 54 |
+
- [x] Add GGUF conversion capability
|
| 55 |
+
|
| 56 |
+
### 4.2 Phase 2: Enhanced Features
|
| 57 |
+
- [x] Implement synthetic dataset creation
|
| 58 |
+
- [ ] Improve error handling and user feedback
|
| 59 |
+
- [ ] Implement progress tracking for long-running operations
|
| 60 |
+
- [ ] Add support for custom model configurations
|
| 61 |
+
|
| 62 |
+
### 4.3 Phase 3: Optimization and Advanced Features
|
| 63 |
+
- [ ] Optimize performance for large datasets and models
|
| 64 |
+
- [ ] Implement advanced fine-tuning techniques (e.g., LoRA, QLoRA)
|
| 65 |
+
- [ ] Add support for distributed training
|
| 66 |
+
- [ ] Implement model comparison tools
|
| 67 |
+
|
| 68 |
+
## 5. Testing Plan
|
| 69 |
+
|
| 70 |
+
### 5.1 Unit Testing
|
| 71 |
+
- Develop unit tests for each utility module
|
| 72 |
+
- Ensure proper error handling and edge case coverage
|
| 73 |
+
|
| 74 |
+
### 5.2 Integration Testing
|
| 75 |
+
- Test the interaction between different modules
|
| 76 |
+
- Verify data flow from UI to backend and vice versa
|
| 77 |
+
|
| 78 |
+
### 5.3 User Acceptance Testing
|
| 79 |
+
- Conduct usability testing with potential users
|
| 80 |
+
- Gather feedback on UI intuitiveness and feature completeness
|
| 81 |
+
|
| 82 |
+
## 6. Deployment Plan
|
| 83 |
+
|
| 84 |
+
### 6.1 Local Deployment
|
| 85 |
+
- Provide clear instructions for local installation and setup
|
| 86 |
+
- Create a comprehensive README with usage guidelines
|
| 87 |
+
|
| 88 |
+
### 6.2 Cloud Deployment (Future Consideration)
|
| 89 |
+
- Explore options for cloud deployment (e.g., Hugging Face Spaces, Google Cloud)
|
| 90 |
+
- Implement necessary security measures for cloud deployment
|
| 91 |
+
|
| 92 |
+
## 7. Documentation
|
| 93 |
+
|
| 94 |
+
- Create user documentation explaining each feature and its usage
|
| 95 |
+
- Develop technical documentation for future maintainers
|
| 96 |
+
- Include examples and use cases in the documentation
|
| 97 |
+
|
| 98 |
+
## 8. Maintenance and Updates
|
| 99 |
+
|
| 100 |
+
- Establish a process for regular updates to supported models and libraries
|
| 101 |
+
- Plan for ongoing bug fixes and feature enhancements based on user feedback
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
This project plan provides a roadmap for the development, testing, and deployment of the LLM Finetuner application. It should be reviewed and updated regularly as the project progresses and new requirements or challenges emerge.
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
| 3 |
+
unsloth
|
| 4 |
+
datasets
|
| 5 |
+
trl
|
| 6 |
+
transformers
|
| 7 |
+
openai
|
| 8 |
+
anthropic
|
| 9 |
+
requests
|
| 10 |
+
tqdm
|
| 11 |
+
accelerate
|
| 12 |
+
bitsandbytes
|
| 13 |
+
huggingface_hub
|
| 14 |
+
triton
|
| 15 |
+
peft
|
training_utils.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unsloth import FastLanguageModel
|
| 2 |
+
from trl import SFTTrainer
|
| 3 |
+
from transformers import TrainingArguments, DataCollatorForSeq2Seq
|
| 4 |
+
from unsloth import is_bfloat16_supported
|
| 5 |
+
from unsloth.chat_templates import train_on_responses_only
|
| 6 |
+
|
| 7 |
+
def finetune_model(model, tokenizer, dataset, learning_rate, batch_size, num_epochs):
    """
    Fine-tune a model on a dataset using LoRA adapters (Unsloth) and TRL's SFTTrainer.

    The model is wrapped with ``FastLanguageModel.get_peft_model`` (rank-16
    adapters on the attention and MLP projection modules), trained with loss
    restricted to assistant responses via ``train_on_responses_only``, and
    the trainer is returned after ``train()`` completes.

    Args:
        model: The pre-trained model to fine-tune.
        tokenizer: The tokenizer associated with the model.
        dataset: The dataset to use for fine-tuning; it is read through the
            ``"text"`` field.
        learning_rate (float): Learning rate for optimization.
        batch_size (int): Per-device number of training examples per step.
            Whole-number floats such as ``2.0`` are accepted and converted.
        num_epochs (int | float): Number of passes through the dataset.

    Returns:
        SFTTrainer: The trainer object after training has run.

    Raises:
        ValueError: If ``batch_size`` is not a positive whole number.
    """
    # Numeric UI widgets can hand over whole numbers as floats (e.g. 2.0),
    # but the data loader requires a genuine int batch size. Convert here and
    # reject fractional or non-positive values instead of failing deep inside
    # the trainer.
    per_device_batch_size = int(batch_size)
    if per_device_batch_size != batch_size or per_device_batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    # Prepare the model for training by attaching LoRA adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )

    # Exactly one of fp16/bf16 is enabled, depending on hardware support.
    bf16_supported = is_bfloat16_supported()

    # Set up the trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=model.config.max_position_embeddings,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        dataset_num_proc=2,
        packing=False,
        args=TrainingArguments(
            per_device_train_batch_size=per_device_batch_size,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=not bf16_supported,
            bf16=bf16_supported,
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir="outputs",
        ),
    )

    # Apply train_on_responses_only. The markers below are the Llama-3 style
    # header tokens, so the tokenizer's chat template must produce them.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
        response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
    )

    # Train the model
    trainer.train()

    return trainer
|
ui.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from model_utils import load_model
|
| 4 |
+
from dataset_utils import prepare_dataset, create_synthetic_dataset
|
| 5 |
+
from training_utils import finetune_model
|
| 6 |
+
from inference_utils import test_model
|
| 7 |
+
from gguf_utils import convert_to_gguf
|
| 8 |
+
from unsloth import FastLanguageModel
|
| 9 |
+
from unsloth.chat_templates import get_chat_template
|
| 10 |
+
from upload_utils import upload_to_huggingface, upload_gguf_to_huggingface
|
| 11 |
+
|
| 12 |
+
def create_gradio_interface():
|
| 13 |
+
models = [
|
| 14 |
+
"unsloth/Meta-Llama-3.1-8B-bnb-4bit",
|
| 15 |
+
"unsloth/Mistral-Small-Instruct-2409",
|
| 16 |
+
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
|
| 17 |
+
"unsloth/Phi-3.5-mini-instruct",
|
| 18 |
+
"unsloth/Phi-3-medium-4k-instruct",
|
| 19 |
+
"unsloth/gemma-2-9b-bnb-4bit",
|
| 20 |
+
"unsloth/gemma-2-27b-bnb-4bit",
|
| 21 |
+
"unsloth/Llama-3.2-3B-Instruct",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
with gr.Blocks() as demo:
|
| 25 |
+
gr.Markdown("# LLM Finetuner")
|
| 26 |
+
|
| 27 |
+
model = gr.State(None)
|
| 28 |
+
tokenizer = gr.State(None)
|
| 29 |
+
dataset = gr.State(None)
|
| 30 |
+
|
| 31 |
+
with gr.Tab("Settings"):
|
| 32 |
+
hf_token = gr.Textbox(label="Hugging Face Token", type="password")
|
| 33 |
+
model_path = gr.Dropdown(label="Model", choices=models, value="unsloth/Llama-3.2-3B-Instruct")
|
| 34 |
+
load_model_btn = gr.Button("Load Model")
|
| 35 |
+
load_model_output = gr.Textbox(label="Load Model Output")
|
| 36 |
+
|
| 37 |
+
with gr.Tab("Dataset"):
|
| 38 |
+
with gr.Group():
|
| 39 |
+
gr.Markdown("## Use Existing Dataset")
|
| 40 |
+
dataset_source = gr.Radio(["Hugging Face", "Local File"], label="Dataset Source", value="Hugging Face")
|
| 41 |
+
hf_dataset_path = gr.Textbox(label="Hugging Face Dataset Path", value="mlabonne/FineTome-100k")
|
| 42 |
+
local_dataset_path = gr.File(label="Upload Local Dataset (JSON or CSV)", visible=False)
|
| 43 |
+
prepare_dataset_btn = gr.Button("Prepare Dataset")
|
| 44 |
+
prepare_dataset_output = gr.Textbox(label="Prepare Dataset Output")
|
| 45 |
+
|
| 46 |
+
with gr.Group():
|
| 47 |
+
gr.Markdown("## Create Synthetic Dataset")
|
| 48 |
+
examples = gr.Textbox(label="Example Conversations", lines=10, placeholder="Enter example conversations here...")
|
| 49 |
+
expected_structure = gr.Textbox(label="Expected Dataset Structure", lines=5, placeholder="Enter the expected structure for the dataset...")
|
| 50 |
+
num_samples = gr.Number(label="Number of Samples to Generate", value=100)
|
| 51 |
+
ai_provider = gr.Radio(["OpenAI", "Anthropic", "Ollama"], label="AI Provider")
|
| 52 |
+
api_key = gr.Textbox(label="API Key", type="password")
|
| 53 |
+
ollama_model = gr.Textbox(label="Ollama Model Name", visible=False)
|
| 54 |
+
create_dataset_btn = gr.Button("Create Synthetic Dataset")
|
| 55 |
+
create_dataset_output = gr.Textbox(label="Create Dataset Output")
|
| 56 |
+
|
| 57 |
+
with gr.Tab("Training"):
|
| 58 |
+
learning_rate = gr.Number(label="Learning Rate", value=2e-4)
|
| 59 |
+
batch_size = gr.Number(label="Batch Size", value=2)
|
| 60 |
+
num_epochs = gr.Number(label="Number of Epochs", value=1)
|
| 61 |
+
train_btn = gr.Button("Start Training")
|
| 62 |
+
train_output = gr.Textbox(label="Training Output")
|
| 63 |
+
|
| 64 |
+
with gr.Tab("Test"):
|
| 65 |
+
test_input = gr.Textbox(label="Test Input")
|
| 66 |
+
test_btn = gr.Button("Test Model")
|
| 67 |
+
test_output = gr.Textbox(label="Model Output")
|
| 68 |
+
|
| 69 |
+
with gr.Tab("GGUF Conversion"):
|
| 70 |
+
gguf_output_path = gr.Textbox(label="GGUF Output Path")
|
| 71 |
+
gguf_quant_method = gr.Dropdown(
|
| 72 |
+
label="Quantization Method",
|
| 73 |
+
choices=["q8_0", "q4_k_m", "q5_k_m", "f16"],
|
| 74 |
+
value="q8_0"
|
| 75 |
+
)
|
| 76 |
+
gguf_convert_btn = gr.Button("Convert to GGUF")
|
| 77 |
+
gguf_output = gr.Textbox(label="GGUF Conversion Output")
|
| 78 |
+
|
| 79 |
+
with gr.Tab("Upload to Hugging Face"):
|
| 80 |
+
repo_name = gr.Textbox(label="Hugging Face Repository Name")
|
| 81 |
+
model_type = gr.Radio(["Fine-tuned Model", "GGUF Converted Model"], label="Model Type to Upload", value="Fine-tuned Model")
|
| 82 |
+
gguf_file_path = gr.Textbox(label="GGUF File Path (if uploading GGUF model)", visible=False)
|
| 83 |
+
upload_btn = gr.Button("Upload to Hugging Face")
|
| 84 |
+
upload_output = gr.Textbox(label="Upload Output")
|
| 85 |
+
|
| 86 |
+
def load_model_and_tokenizer(model_path, hf_token):
|
| 87 |
+
model_val, tokenizer_val = load_model(model_path, hf_token)
|
| 88 |
+
tokenizer_val = get_chat_template(tokenizer_val, chat_template="llama-3.1")
|
| 89 |
+
return model_val, tokenizer_val, "Model and tokenizer loaded successfully!"
|
| 90 |
+
|
| 91 |
+
def update_ollama_visibility(choice):
|
| 92 |
+
return gr.update(visible=(choice == "Ollama"))
|
| 93 |
+
|
| 94 |
+
def update_dataset_input_visibility(choice):
|
| 95 |
+
return gr.update(visible=(choice == "Hugging Face")), gr.update(visible=(choice == "Local File"))
|
| 96 |
+
|
| 97 |
+
def update_gguf_path_visibility(choice):
|
| 98 |
+
return gr.update(visible=(choice == "GGUF Converted Model"))
|
| 99 |
+
|
| 100 |
+
load_model_btn.click(
|
| 101 |
+
load_model_and_tokenizer,
|
| 102 |
+
inputs=[model_path, hf_token],
|
| 103 |
+
outputs=[model, tokenizer, load_model_output]
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
dataset_source.change(
|
| 107 |
+
update_dataset_input_visibility,
|
| 108 |
+
inputs=[dataset_source],
|
| 109 |
+
outputs=[hf_dataset_path, local_dataset_path]
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
model_type.change(
|
| 113 |
+
update_gguf_path_visibility,
|
| 114 |
+
inputs=[model_type],
|
| 115 |
+
outputs=[gguf_file_path]
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def prepare_dataset_wrapper(source, hf_path, local_file, hf_token, tokenizer_val):
|
| 119 |
+
if tokenizer_val is None:
|
| 120 |
+
return "Error: Model and tokenizer not loaded. Please load the model first."
|
| 121 |
+
|
| 122 |
+
if source == "Hugging Face":
|
| 123 |
+
dataset_val = prepare_dataset("huggingface", hf_path, tokenizer_val, hf_token)
|
| 124 |
+
elif source == "Local File":
|
| 125 |
+
if local_file is not None:
|
| 126 |
+
dataset_val = prepare_dataset("local", local_file.name, tokenizer_val)
|
| 127 |
+
else:
|
| 128 |
+
return "No file uploaded. Please upload a local dataset file."
|
| 129 |
+
else:
|
| 130 |
+
return "Invalid dataset source selected."
|
| 131 |
+
|
| 132 |
+
return dataset_val, "Dataset prepared successfully!"
|
| 133 |
+
|
| 134 |
+
prepare_dataset_btn.click(
|
| 135 |
+
prepare_dataset_wrapper,
|
| 136 |
+
inputs=[dataset_source, hf_dataset_path, local_dataset_path, hf_token, tokenizer],
|
| 137 |
+
outputs=[dataset, prepare_dataset_output]
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def create_synthetic_dataset_wrapper(examples, expected_structure, num_samples, ai_provider, api_key, ollama_model, tokenizer_val):
|
| 141 |
+
if tokenizer_val is None:
|
| 142 |
+
return "Error: Model and tokenizer not loaded. Please load the model first."
|
| 143 |
+
|
| 144 |
+
dataset_val = create_synthetic_dataset(examples, expected_structure, num_samples, ai_provider, api_key, ollama_model)
|
| 145 |
+
return dataset_val, "Synthetic dataset created successfully!"
|
| 146 |
+
|
| 147 |
+
create_dataset_btn.click(
|
| 148 |
+
create_synthetic_dataset_wrapper,
|
| 149 |
+
inputs=[examples, expected_structure, num_samples, ai_provider, api_key, ollama_model, tokenizer],
|
| 150 |
+
outputs=[dataset, create_dataset_output]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
ai_provider.change(update_ollama_visibility, inputs=[ai_provider], outputs=[ollama_model])
|
| 154 |
+
|
| 155 |
+
def train_model_wrapper(model_val, tokenizer_val, dataset_val, learning_rate, batch_size, num_epochs):
|
| 156 |
+
if model_val is None or tokenizer_val is None:
|
| 157 |
+
return "Error: Model and tokenizer not loaded. Please load the model first."
|
| 158 |
+
if dataset_val is None:
|
| 159 |
+
return "Error: Dataset not prepared. Please prepare or create a dataset first."
|
| 160 |
+
|
| 161 |
+
try:
|
| 162 |
+
trainer = finetune_model(model_val, tokenizer_val, dataset_val, learning_rate, batch_size, num_epochs)
|
| 163 |
+
return "Training completed successfully!"
|
| 164 |
+
except Exception as e:
|
| 165 |
+
return f"Error during training: {str(e)}"
|
| 166 |
+
|
| 167 |
+
train_btn.click(
|
| 168 |
+
train_model_wrapper,
|
| 169 |
+
inputs=[model, tokenizer, dataset, learning_rate, batch_size, num_epochs],
|
| 170 |
+
outputs=[train_output]
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def test_model_wrapper(model_val, tokenizer_val, test_input):
|
| 174 |
+
if model_val is None or tokenizer_val is None:
|
| 175 |
+
return "Error: Model and tokenizer not loaded. Please load the model first."
|
| 176 |
+
|
| 177 |
+
FastLanguageModel.for_inference(model_val) # Enable native 2x faster inference
|
| 178 |
+
messages = [{"role": "user", "content": test_input}]
|
| 179 |
+
inputs = tokenizer_val.apply_chat_template(
|
| 180 |
+
messages,
|
| 181 |
+
tokenize=True,
|
| 182 |
+
add_generation_prompt=True,
|
| 183 |
+
return_tensors="pt"
|
| 184 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 185 |
+
|
| 186 |
+
outputs = model_val.generate(input_ids=inputs, max_new_tokens=128, temperature=1.5, min_p=0.1)
|
| 187 |
+
return tokenizer_val.batch_decode(outputs)[0]
|
| 188 |
+
|
| 189 |
+
test_btn.click(
|
| 190 |
+
test_model_wrapper,
|
| 191 |
+
inputs=[model, tokenizer, test_input],
|
| 192 |
+
outputs=[test_output]
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def convert_to_gguf_wrapper(model_val, tokenizer_val, gguf_output_path, gguf_quant_method):
|
| 196 |
+
if model_val is None or tokenizer_val is None:
|
| 197 |
+
return "Error: Model and tokenizer not loaded. Please load the model first."
|
| 198 |
+
|
| 199 |
+
output = convert_to_gguf(model_val, tokenizer_val, gguf_output_path, gguf_quant_method)
|
| 200 |
+
return output
|
| 201 |
+
|
| 202 |
+
gguf_convert_btn.click(
|
| 203 |
+
convert_to_gguf_wrapper,
|
| 204 |
+
inputs=[model, tokenizer, gguf_output_path, gguf_quant_method],
|
| 205 |
+
outputs=[gguf_output]
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
def upload_to_hf_wrapper(model_val, tokenizer_val, repo_name, hf_token, model_type, gguf_file_path):
    """Dispatch an upload to the Hugging Face Hub based on the selected model type.

    Args:
        model_val: The loaded model, or ``None``; only used for "Fine-tuned Model".
        tokenizer_val: The loaded tokenizer, or ``None``; only used for "Fine-tuned Model".
        repo_name: Target repository name, forwarded to the upload helper.
        hf_token: Access token, forwarded to the upload helper.
        model_type: Either "Fine-tuned Model" or "GGUF Converted Model".
        gguf_file_path: Path to the GGUF file; only used for "GGUF Converted Model".

    Returns:
        The upload helper's return value, or an error string when a required
        input is missing or ``model_type`` is not one of the two known values.
    """
    if model_type == "Fine-tuned Model":
        if model_val is None or tokenizer_val is None:
            return "Error: Model and tokenizer not loaded. Please load the model first."
        return upload_to_huggingface(model_val, tokenizer_val, repo_name, hf_token)

    if model_type == "GGUF Converted Model":
        if not gguf_file_path:
            return "Error: GGUF file path not provided. Please enter the path to the GGUF file."
        return upload_gguf_to_huggingface(gguf_file_path, repo_name, hf_token)

    # Any other selection is rejected outright.
    return "Error: Invalid model type selected."
|
| 220 |
+
|
| 221 |
+
# Register upload_to_hf_wrapper as the click handler of the upload button;
# inputs map positionally onto the handler's six parameters and the returned
# value is written to `upload_output`.
# NOTE(review): `model` / `tokenizer` come from earlier in the enclosing
# function (not visible here) — presumably gr.State holders; confirm.
upload_btn.click(
    upload_to_hf_wrapper,
    inputs=[model, tokenizer, repo_name, hf_token, model_type, gguf_file_path],
    outputs=[upload_output]
)
|
| 226 |
+
|
| 227 |
+
return demo
|
| 228 |
+
|
| 229 |
+
# Script entry point: build the interface and launch it with default settings.
# The interface object is bound to the module-level name `demo` before launching.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()
|
unsloth_compiled_cache/UnslothAlignPropTrainer.py
ADDED
|
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
torch_compile_options = {
|
| 26 |
+
"epilogue_fusion" : True,
|
| 27 |
+
"max_autotune" : False,
|
| 28 |
+
"shape_padding" : True,
|
| 29 |
+
"trace.enabled" : False,
|
| 30 |
+
"triton.cudagraphs" : False,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Return the log-softmax of ``logits`` evaluated only at the entries picked by ``index``.

    The computation is done in float32 and avoids materialising the full
    log-softmax tensor by using ``log_softmax(x)_i = x_i - logsumexp(x)``.

    Args:
        logits: Tensor whose last dimension is reduced over.
        index: Integer tensor of positions along the last dimension of
            ``logits``; it is unsqueezed before ``gather``, so it is expected
            to have the shape of ``logits`` without the last dimension.

    Returns:
        Tensor with the shape of ``index`` holding the selected log-probabilities.
    """
    logits_fp32 = logits.to(torch.float32)
    # Normaliser shared by every entry along the last dimension.
    normalizer = torch.logsumexp(logits_fp32, dim = -1)
    # Raw logit at each requested position.
    picked = torch.gather(logits_fp32, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    return picked - normalizer
|
| 42 |
+
@dataclass
class UnslothAlignPropConfig(AlignPropConfig):
    """
    Configuration class for the [`AlignPropTrainer`], extended with two Unsloth-specific fields
    (`vllm_sampling_params` and `unsloth_num_chunks`).

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    The defaults listed below are the ones declared by this class's `__init__`. Arguments marked "inherited" are
    not explicit parameters here; they can only be supplied through `**kwargs`, and their defaults are whatever
    `AlignPropConfig` defines (the values quoted for them come from the upstream documentation — confirm against
    the installed `trl` version).

    Parameters:
        exp_name (`str`, *optional*, defaults to `"main"`):
            Name of this experiment.
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `3407`):
            Random seed for reproducibility.
        log_with (`str` or `None`, *optional*, defaults to `None`):
            Log with either `"wandb"` or `"tensorboard"`. Check
            [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
        log_image_freq (`int`, *optional*, defaults to `1`):
            Frequency for logging images.
        tracker_kwargs (`dict[str, Any]`, *optional*, inherited, upstream default `{}`):
            Keyword arguments for the tracker (e.g., `wandb_project`).
        accelerator_kwargs (`dict[str, Any]`, *optional*, inherited, upstream default `{}`):
            Keyword arguments for the accelerator.
        project_kwargs (`dict[str, Any]`, *optional*, inherited, upstream default `{}`):
            Keyword arguments for the accelerator project config (e.g., `logging_dir`).
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Path to resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Whether to use the 8bit Adam optimizer from `bitsandbytes`.
        train_learning_rate (`float`, *optional*, defaults to `5e-05`):
            Learning rate.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Beta1 for Adam optimizer.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Beta2 for Adam optimizer.
        train_adam_weight_decay (`float`, *optional*, defaults to `0.01`):
            Weight decay for Adam optimizer.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-08`):
            Epsilon value for Adam optimizer.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `2`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        negative_prompts (`str` or `None`, *optional*, defaults to `None`):
            Comma-separated list of prompts to use as negative examples.
        truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
            If `True`, randomized truncation to different diffusion timesteps is used.
        truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
            Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
        truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, inherited, upstream default `(0, 50)`):
            Range of diffusion timesteps for randomized truncated backpropagation.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model to the Hub.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance only; it is not forwarded to `AlignPropConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance only; it is not
            forwarded to `AlignPropConfig`.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        exp_name = 'main',
        run_name = '',
        seed = 3407,
        log_with = None,
        log_image_freq = 1,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        negative_prompts = None,
        truncated_backprop_rand = True,
        truncated_backprop_timestep = 49,
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):

        # Everything except the two Unsloth-only fields is handed to the parent
        # config; extra keyword arguments are passed through untouched.
        super().__init__(
            exp_name = exp_name,
            run_name = run_name,
            seed = seed,
            log_with = log_with,
            log_image_freq = log_image_freq,
            tracker_project_name = tracker_project_name,
            logdir = logdir,
            num_epochs = num_epochs,
            save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision,
            allow_tf32 = allow_tf32,
            resume_from = resume_from,
            sample_num_steps = sample_num_steps,
            sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1,
            train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            negative_prompts = negative_prompts,
            truncated_backprop_rand = truncated_backprop_rand,
            truncated_backprop_timestep = truncated_backprop_timestep,
            push_to_hub = push_to_hub,**kwargs)
        # Unsloth-only settings are kept as plain instance attributes.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
pass
|
| 199 |
+
|
| 200 |
+
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
    """
    Trainer that fine-tunes a Stable Diffusion pipeline by backpropagating a differentiable reward
    through the image sampling process (AlignProp, https://huggingface.co/papers/2310.03739).

    Each call to `step` generates images for freshly sampled prompts, scores them with the
    user-supplied reward function, and minimises `10.0 - mean(reward)` with AdamW.
    """

    _tag_names = ["trl", "alignprop"]

    def __init__(
        self,
        config: AlignPropConfig,
        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
        prompt_function: Callable[[], tuple[str, Any]],
        sd_pipeline: DDPOStableDiffusionPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        """
        Set up the accelerator, pipeline, optimizer and (optionally) resume state.

        Args:
            config (`AlignPropConfig`):
                Training configuration.
            reward_function (`Callable`):
                Called as `reward_function(images, prompts, prompt_metadata)`; must return a
                `(reward, reward_metadata)` pair (see `compute_rewards`).
            prompt_function (`Callable`):
                Called with no arguments; must return a `(prompt, metadata)` pair.
            sd_pipeline (`DDPOStableDiffusionPipeline`):
                Pipeline whose trainable layers are optimised.
            image_samples_hook (`Callable` or `None`):
                Optional callback invoked from `step` to log generated images; a warning is
                emitted when it is `None`.

        Raises:
            ValueError: If `config.resume_from` points to a directory without any `checkpoint_*` entries.
        """
        if image_samples_hook is None:
            warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.reward_fn = reward_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

        if self.config.resume_from:
            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
            if "checkpoint_" not in os.path.basename(self.config.resume_from):
                # get the most recent checkpoint in this directory
                checkpoints = list(
                    filter(
                        lambda x: "checkpoint_" in x,
                        os.listdir(self.config.resume_from),
                    )
                )
                if len(checkpoints) == 0:
                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
                self.config.resume_from = os.path.join(
                    self.config.resume_from,
                    f"checkpoint_{checkpoint_numbers[-1]}",
                )

                # Continue checkpoint numbering after the one being resumed.
                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
            # the total number of optimizer steps to accumulate across.
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
            **self.config.accelerator_kwargs,
        )

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

        if self.accelerator.is_main_process:
            # TensorBoard receives the flat config dict; every other tracker gets it
            # nested under the "alignprop_trainer_config" key.
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(alignprop_trainer_config=config.to_dict())
                if not is_using_tensorboard
                else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline

        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        # Checkpoint save/load is delegated to the pipeline through these hooks.
        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        # Embedding of the negative prompt(s) (an empty string when none are configured),
        # computed once here and repeated per batch in `_generate_samples`.
        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
        # more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            # With LoRA, keep only the parameters that require gradients as the trainable set.
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            # The epoch to resume at is parsed from the `checkpoint_<n>` suffix.
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0

    def compute_rewards(self, prompt_image_pairs):
        """
        Score generated images with the user-supplied reward function.

        Args:
            prompt_image_pairs (dict): Must contain the keys `"images"`, `"prompts"` and `"prompt_metadata"`.

        Returns:
            The reward returned by `self.reward_fn`; the accompanying reward metadata is discarded.
        """
        reward, reward_metadata = self.reward_fn(
            prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
        )
        return reward

    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.

        Raises:
            ValueError: If gradients were not synchronised after the accumulation loop.
        """
        info = defaultdict(list)

        self.sd_pipeline.unet.train()

        # One sample/backward pass per accumulation step.
        for _ in range(self.config.train_gradient_accumulation_steps):
            with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
                prompt_image_pairs = self._generate_samples(
                    batch_size=self.config.train_batch_size,
                )

                rewards = self.compute_rewards(prompt_image_pairs)

                prompt_image_pairs["rewards"] = rewards

                # Detached copy gathered across processes, used only for logging statistics.
                rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()

                loss = self.calculate_loss(rewards)

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(
                        self.trainable_layers.parameters()
                        if not isinstance(self.trainable_layers, list)
                        else self.trainable_layers,
                        self.config.train_max_grad_norm,
                    )

                self.optimizer.step()
                self.optimizer.zero_grad()

            info["reward_mean"].append(rewards_vis.mean())
            info["reward_std"].append(rewards_vis.std())
            info["loss"].append(loss.item())

        # Checks if the accelerator has performed an optimization step behind the scenes
        if self.accelerator.sync_gradients:
            # log training-related stuff
            info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
            info = self.accelerator.reduce(info, reduction="mean")
            info.update({"epoch": epoch})
            self.accelerator.log(info, step=global_step)
            global_step += 1
            info = defaultdict(list)
        else:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )
        # Logs generated images
        if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
            # Only the samples from the last accumulation iteration are passed to the callback.
            self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])

        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
            self.accelerator.save_state()

        return global_step

    def calculate_loss(self, rewards):
        """
        Calculate the loss for a batch of an unpacked sample

        Args:
            rewards (torch.Tensor):
                Differentiable reward scalars for each generated image, shape: [batch_size]

        Returns:
            loss (torch.Tensor)
            (all of these are of shape (1,))
        """
        # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
        loss = 10.0 - (rewards).mean()
        return loss

    def loss(
        self,
        advantages: torch.Tensor,
        clip_range: float,
        ratio: torch.Tensor,
    ):
        """
        Return the mean of the element-wise maximum of `-advantages * ratio` and the same
        product with `ratio` clamped to `[1 - clip_range, 1 + clip_range]`.

        This method is not called anywhere else in this class.
        """
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(
            ratio,
            1.0 - clip_range,
            1.0 + clip_range,
        )
        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))

    def _setup_optimizer(self, trainable_layers_parameters):
        """
        Build the AdamW optimizer for the given parameters.

        Uses `bitsandbytes.optim.AdamW8bit` when `config.train_use_8bit_adam` is set
        (imported lazily so `bitsandbytes` is only required in that case), otherwise
        `torch.optim.AdamW`. Hyper-parameters come from the `train_*` config fields.
        """
        if self.config.train_use_8bit_adam:
            import bitsandbytes

            optimizer_cls = bitsandbytes.optim.AdamW8bit
        else:
            optimizer_cls = torch.optim.AdamW

        return optimizer_cls(
            trainable_layers_parameters,
            lr=self.config.train_learning_rate,
            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
            weight_decay=self.config.train_adam_weight_decay,
            eps=self.config.train_adam_epsilon,
        )

    def _save_model_hook(self, models, weights, output_dir):
        """Accelerate save-state pre-hook: let the pipeline write the checkpoint itself."""
        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model

    def _load_model_hook(self, models, input_dir):
        """Accelerate load-state pre-hook: let the pipeline read the checkpoint itself."""
        self.sd_pipeline.load_checkpoint(models, input_dir)
        models.pop()  # ensures that accelerate doesn't try to handle loading of the model

    def _generate_samples(self, batch_size, with_grad=True, prompts=None):
        """
        Generate samples from the model

        Args:
            batch_size (int): Batch size to use for sampling
            with_grad (bool): Whether the generated RGBs should have gradients attached to it.
            prompts (sequence of str or `None`): Prompts to use. When `None`, `batch_size` prompts
                (and their metadata) are drawn from `self.prompt_fn`; otherwise the metadata is a
                list of `batch_size` empty dicts.

        Returns:
            prompt_image_pairs (dict[Any]): with keys `"images"`, `"prompts"` and `"prompt_metadata"`.
        """
        prompt_image_pairs = {}

        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

        if prompts is None:
            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
        else:
            prompt_metadata = [{} for _ in range(batch_size)]

        prompt_ids = self.sd_pipeline.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)

        prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

        if with_grad:
            # Gradient-carrying sampling path; the truncated-backprop settings apply only here.
            sd_output = self.sd_pipeline.rgb_with_grad(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                truncated_backprop_rand=self.config.truncated_backprop_rand,
                truncated_backprop_timestep=self.config.truncated_backprop_timestep,
                truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
                output_type="pt",
            )
        else:
            sd_output = self.sd_pipeline(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                output_type="pt",
            )

        images = sd_output.images

        prompt_image_pairs["images"] = images
        prompt_image_pairs["prompts"] = prompts
        prompt_image_pairs["prompt_metadata"] = prompt_metadata

        return prompt_image_pairs

    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs

        Args:
            epochs (`int` or `None`): Epoch index to stop before; defaults to `config.num_epochs`.
                Training starts at `self.first_epoch`, while the global step counter always
                starts at 0, including when resuming.
        """
        global_step = 0
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step = self.step(epoch, global_step)

    def _save_pretrained(self, save_directory):
        """Save the pipeline to `save_directory`, then write a model card via `create_model_card`."""
        self.sd_pipeline.save_pretrained(save_directory)
        self.create_model_card()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        # NOTE(review): this method reads `self.is_world_process_zero`, `self.model`,
        # `self.hub_model_id` and `self.args`, none of which are assigned in this class's
        # `__init__` — confirm they are provided elsewhere before relying on
        # `_save_pretrained`, which calls this method.
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @article{prabhudesai2024aligning,
            title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
            author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
            year = 2024,
            eprint = {arXiv:2310.03739}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="AlignProp",
            trainer_citation=citation,
            paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
            paper_id="2310.03739",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 595 |
+
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
    """
    Unsloth entry point for the AlignProp trainer.

    AlignProp ("Aligning Text-to-Image Diffusion Models with Reward Backpropagation") optimises diffusion models.
    Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
    As of now only Stable Diffusion based pipelines are supported

    Attributes:
        config (`AlignPropConfig`):
            Configuration object for AlignPropTrainer. Check the documentation of `AlignPropConfig` for more details.
            If `None` is passed, a default `UnslothAlignPropConfig()` is used.
        reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
            Reward function to be used
        prompt_function (`Callable[[], tuple[str, Any]]`):
            Function to generate prompts to guide model
        sd_pipeline (`DDPOStableDiffusionPipeline`):
            Stable Diffusion pipeline to be used for training.
        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
            Hook to be called to log images
        **kwargs:
            Forwarded unchanged to the parent trainer's constructor.

    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        # This trainer is configured through `config`; there is no `args`
        # parameter. The previous check read `args`, which Python treats as an
        # unassigned local here, so every construction raised UnboundLocalError.
        if config is None: config = UnslothAlignPropConfig()
        other_metrics = []

        # Imported lazily, as in the rest of this generated module.
        # NOTE(review): presumably registers extra statistics for logging under
        # the 'alignprop_trainer' key (none are added here) - confirm in unsloth_zoo.
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('alignprop_trainer', other_metrics)

        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,
            **kwargs
        )
|
unsloth_compiled_cache/UnslothBCOTrainer.py
ADDED
|
@@ -0,0 +1,1824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, deepspeed, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options handed to `torch.compile(..., options=...)` by the compiled helpers
# in this module (see the decorator on `selective_log_softmax`).
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Return the log-softmax of `logits` (last dim) evaluated only at `index`.

    Args:
        logits: Tensor of scores; it is cast to float32 before any arithmetic.
        index: Integer tensor of positions to pick along the last dimension of
            `logits`; it is given a trailing axis internally for the gather.

    Returns:
        A float32 tensor shaped like `index`, holding
        `logits[..., index] - logsumexp(logits, dim=-1)`.
    """
    logits_fp32 = logits.to(torch.float32)
    # Pick the score of each requested entry along the last dimension.
    picked = logits_fp32.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    # log_softmax(x_i) = x_i - logsumexp(x)
    normaliser = logits_fp32.logsumexp(dim = -1)
    return picked - normaliser
|
| 42 |
+
@dataclass
class UnslothBCOConfig(BCOConfig):
    """
    Configuration class for the [`BCOTrainer`], extended with Unsloth-specific fields.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Besides forwarding every argument to `BCOConfig`, the constructor:
      * rejects a `learning_rate` below 1e-7 (`FloatingPointError`) or above 1 (`OverflowError`);
      * when `output_dir` is `None` and the save settings are still `save_strategy='steps'` / `save_steps=500`,
        uses `'unsloth_training_checkpoints'` as output directory and switches `save_strategy` to `'no'`;
      * when `dataset_num_proc` is `None`, sets it to `multiprocessing.cpu_count()`.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
            to use the default data collator.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion. This argument is required if you want to use the default data collator
            and your model is an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int` or `None`, *optional*, defaults to `None`):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
            This argument is required if you want to use the default data collator.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
            during evaluation.
        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify if the model returned by the callable is an encoder-decoder model.
        precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
            Whether to precompute reference model log probabilities for training and evaluation datasets. This is
            useful when training without the reference model to reduce the total GPU memory needed.
        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
            string.
        ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
            from a string.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
        prompt_sample_size (`int`, *optional*, defaults to `1024`):
            Number of prompts that are fed to density ratio classifier.
        min_density_ratio (`float`, *optional*, defaults to `0.5`):
            Minimum value of the density ratio. The estimated density ratio is clamped to this value.
        max_density_ratio (`float`, *optional*, defaults to `10.0`):
            Maximum value of the density ratio. The estimated density ratio is clamped to this value.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams; stored on the instance, not forwarded to `BCOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunking setting used to reduce memory usage; stored on the instance, not forwarded to `BCOConfig`.

    """

    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={"help": "vLLM SamplingParams"},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={"help": "Chunk size to reduce memory usage. -1 is most efficient."},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        tp_size=0,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        evaluation_strategy=None,
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=False,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        dispatch_batches=None,
        split_batches=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        eval_use_gather_object=False,
        average_tokens_across_devices=False,
        max_length=1024,
        max_prompt_length=512,
        max_completion_length=None,
        beta=0.1,
        label_pad_token_id=-100,
        padding_value=None,
        truncation_mode='keep_end',
        disable_dropout=True,
        generate_during_eval=False,
        is_encoder_decoder=None,
        precompute_ref_log_probs=False,
        model_init_kwargs=None,
        ref_model_init_kwargs=None,
        dataset_num_proc=None,
        prompt_sample_size=1024,
        min_density_ratio=0.5,
        max_density_ratio=10.0,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        # Guard against learning rates outside the accepted range.
        if learning_rate < 1e-7:
            raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1:
            raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        # Save settings left untouched and no output directory given:
        # use a fixed directory name and turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        # Default the dataset worker count to the number of CPUs.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            tp_size=tp_size,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            evaluation_strategy=evaluation_strategy,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            dispatch_batches=dispatch_batches,
            split_batches=split_batches,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            max_completion_length=max_completion_length,
            beta=beta,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            truncation_mode=truncation_mode,
            disable_dropout=disable_dropout,
            generate_during_eval=generate_during_eval,
            is_encoder_decoder=is_encoder_decoder,
            precompute_ref_log_probs=precompute_ref_log_probs,
            model_init_kwargs=model_init_kwargs,
            ref_model_init_kwargs=ref_model_init_kwargs,
            dataset_num_proc=dataset_num_proc,
            prompt_sample_size=prompt_sample_size,
            min_density_ratio=min_density_ratio,
            max_density_ratio=max_density_ratio,
            **kwargs,
        )

        # Unsloth-only settings are kept on the instance after the parent
        # constructor has run; they are not passed to `BCOConfig`.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 419 |
+
|
| 420 |
+
class _UnslothBCOTrainer(Trainer):
|
| 421 |
+
r""""""
|
| 422 |
+
|
| 423 |
+
_tag_names = ["trl", "bco"]
|
| 424 |
+
|
| 425 |
+
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module, str] = None,
    ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: BCOConfig = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    data_collator: Optional[DataCollator] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
    model_adapter_name: Optional[str] = None,
    ref_adapter_name: Optional[str] = None,
    embedding_func: Optional[Callable] = None,
    embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
):
    """Set up models, datasets and bookkeeping for BCO training.

    Args:
        model: Policy model, or a string that is loaded with
            `AutoModelForCausalLM.from_pretrained` using `args.model_init_kwargs`.
        ref_model: Reference model (instance or string, loaded the same way with
            `args.ref_model_init_kwargs`). When `None`, no reference model is kept if
            `model` is a `PeftModel` or `args.precompute_ref_log_probs` is set;
            otherwise `create_reference_model(model)` is used.
        args: A `BCOConfig`. A plain `TrainingArguments` instance is rejected.
        train_dataset: Training data; chat-templated, shuffled, tokenized and processed here.
        eval_dataset: Optional evaluation data, prepared like `train_dataset`.
        processing_class: Tokenizer/processor; required.
        data_collator: Collator; defaults to `DPODataCollatorWithPadding`, in which case
            `args.remove_unused_columns` is forced to `False`.
        model_init, callbacks, optimizers, preprocess_logits_for_metrics, compute_metrics:
            Forwarded unchanged to the parent `Trainer.__init__`.
        peft_config: When given (and PEFT is installed) the model is merged/unloaded if it
            is a `PeftModel` and prepared for k-bit training or gradient checkpointing.
        model_adapter_name: Stored on the trainer as `self.model_adapter_name`.
        ref_adapter_name: Stored on the trainer as `self.ref_adapter_name`.
        embedding_func: Callable producing prompt embeddings. When provided, a
            `LogisticRegression` classifier is fitted on sampled prompt embeddings.
        embedding_tokenizer: Tokenizer paired with `embedding_func`.

    Raises:
        ImportError: scikit-learn is not installed.
        ValueError: On inconsistent arguments (see the individual checks below).
        AttributeError: The parent `Trainer` did not create an `accelerator`.
    """

    def _coerce_torch_dtype(init_kwargs):
        # Convert a string `torch_dtype` entry to a real `torch.dtype`, in place.
        torch_dtype = init_kwargs.get("torch_dtype")
        if torch_dtype is None:
            return
        # Convert to `torch.dtype` if an str is passed
        if isinstance(torch_dtype, str) and torch_dtype != "auto":
            torch_dtype = getattr(torch, torch_dtype)
        if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
            )
        init_kwargs["torch_dtype"] = torch_dtype

    def _enable_input_require_grads(target_model):
        # Make the embedding outputs require grad so gradient checkpointing can backprop.
        # For backward compatibility with older versions of transformers
        if hasattr(target_model, "enable_input_require_grads"):
            target_model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            target_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if not is_sklearn_available():
        raise ImportError(
            "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
        )

    if type(args) is TrainingArguments:
        raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")

    if not isinstance(model, str) and ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, you must pass a copy of it, or `None` if you use peft."
        )

    if args.model_init_kwargs is None:
        model_init_kwargs = {}
    elif not isinstance(model, str):
        raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
    else:
        model_init_kwargs = args.model_init_kwargs
        _coerce_torch_dtype(model_init_kwargs)

    if args.ref_model_init_kwargs is None:
        ref_model_init_kwargs = {}
    elif not isinstance(ref_model, str):
        raise ValueError(
            "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
        )
    else:
        ref_model_init_kwargs = args.ref_model_init_kwargs
        _coerce_torch_dtype(ref_model_init_kwargs)

    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

    if isinstance(ref_model, str):
        ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

    # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
    # has been called in order to properly call autocast if needed.
    self._peft_has_been_casted_to_bf16 = False

    if not is_peft_available() and peft_config is not None:
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            _support_gc_kwargs = hasattr(
                args, "gradient_checkpointing_kwargs"
            ) and "gradient_checkpointing_kwargs" in list(
                inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if _support_gc_kwargs:
                prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
        elif getattr(args, "gradient_checkpointing", False):
            _enable_input_require_grads(model)

        # NOTE(review): this variant does not wrap the model with `peft_config` at this
        # point (the statement here used to be the no-op `model = model`); presumably the
        # caller supplies an already adapter-wrapped model -- confirm.
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
            # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            self._peft_has_been_casted_to_bf16 = True

    # For models that use gradient_checkpointing, we need to attach a hook that enables input
    # to explicitly have `requires_grad=True`, otherwise training will either silently
    # fail or completely fail.
    elif getattr(args, "gradient_checkpointing", False):
        _enable_input_require_grads(model)

    if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
            " Please install `wandb` or `comet-ml` to resolve."
        )

    if model is not None:
        self.is_encoder_decoder = model.config.is_encoder_decoder
    elif args.is_encoder_decoder is None:
        raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
    else:
        self.is_encoder_decoder = args.is_encoder_decoder

    self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
    self.model_adapter_name = model_adapter_name
    self.ref_adapter_name = ref_adapter_name

    # Compare against None explicitly: module truthiness can depend on `__len__`.
    if ref_model is not None:
        self.ref_model = ref_model
    elif self.is_peft_model or args.precompute_ref_log_probs:
        # The `model` with adapters turned off will be used as the reference model
        self.ref_model = None
    else:
        self.ref_model = create_reference_model(model)

    if processing_class is None:
        raise ValueError(
            "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
        )

    if args.max_length is None:
        warnings.warn(
            "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
            "It will be set to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_length = 512
    else:
        max_length = args.max_length

    if args.max_prompt_length is None:
        warnings.warn(
            "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
            "It will be set to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_prompt_length = 128
    else:
        max_prompt_length = args.max_prompt_length

    # A completion length is only tracked for encoder-decoder models.
    max_completion_length = None
    if self.is_encoder_decoder:
        if args.max_completion_length is None:
            warnings.warn(
                "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
                " it will be set to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_completion_length = 128
        else:
            max_completion_length = args.max_completion_length

    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(
            pad_token_id=processing_class.pad_token_id,
            label_pad_token_id=args.label_pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_dpo_data_collator = True
    else:
        self.use_dpo_data_collator = False

    # Disable dropout in the model and reference model
    if args.disable_dropout:
        disable_dropout_in_model(model)
        if self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

    self.max_length = max_length
    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
    self.max_prompt_length = max_prompt_length
    self.truncation_mode = args.truncation_mode
    self.max_completion_length = max_completion_length
    self.precompute_ref_log_probs = args.precompute_ref_log_probs

    # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
    # keep track of first called to avoid computation of future calls
    self._precomputed_train_ref_log_probs = False
    self._precomputed_eval_ref_log_probs = False

    # metric
    self._stored_metrics = defaultdict(lambda: defaultdict(list))

    # BCO parameter
    self.beta = args.beta
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        warnings.warn(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
            UserWarning,
        )

    # Underlying Distribution Matching argument
    self.embedding_func = embedding_func
    self.embedding_tokenizer = embedding_tokenizer

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
    # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
    # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
    # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
    # issued.
    model.warnings_issued["estimate_tokens"] = True

    with PartialState().local_main_process_first():
        # Apply the chat template if needed
        train_dataset = train_dataset.map(
            maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
        )
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
            )
        # Shuffle the datasets
        train_dataset = train_dataset.shuffle(seed=args.data_seed)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.shuffle(seed=args.data_seed)

        # The same keyword arguments are used for the train and eval passes.
        tokenize_kwargs = {"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer}
        fn_kwargs = {
            "prefix": "",
            "is_encoder_decoder": self.is_encoder_decoder,
            "tokenizer": processing_class,
            "max_length": self.max_length,
            "truncation_mode": self.truncation_mode,
            "label_pad_token_id": self.label_pad_token_id,
            "max_prompt_length": self.max_prompt_length,
            "max_completion_length": self.max_completion_length,
        }

        # Tokenize and prepare the training datasets
        train_dataset = train_dataset.map(
            _tokenize,
            batched=True,
            fn_kwargs=tokenize_kwargs,
            num_proc=args.dataset_num_proc,
            desc="Tokenizing train dataset",
        )
        train_dataset = train_dataset.map(
            _process_tokens,
            fn_kwargs=fn_kwargs,
            num_proc=args.dataset_num_proc,
            desc="Processing tokenized train dataset",
        )

        if eval_dataset is not None:
            # Tokenize
            eval_dataset = eval_dataset.map(
                _tokenize,
                fn_kwargs=tokenize_kwargs,
                batched=True,
                num_proc=args.dataset_num_proc,
                desc="Tokenizing eval dataset",
            )

            # Process
            eval_dataset = eval_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                desc="Processing tokenized eval dataset",
            )

        # Split by label; the two subsets feed the density-ratio classifier below.
        desirable = train_dataset.filter(
            lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
        )
        undesirable = train_dataset.filter(
            lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
        )

        desirable = desirable.shuffle(seed=args.data_seed)
        undesirable = undesirable.shuffle(seed=args.data_seed)

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )

    # Deepspeed Zero-3 does not support precompute_ref_log_probs
    if self.is_deepspeed_enabled:
        if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
            raise ValueError(
                "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
            )

    if self.ref_model is None:
        if not (self.is_peft_model or self.precompute_ref_log_probs):
            raise ValueError(
                "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
            )
    else:
        if self.is_deepspeed_enabled:
            self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    self.running = RunningMoments(accelerator=self.accelerator)

    # Without an embedding function there is no classifier to fit.
    if self.embedding_func is None:
        return

    chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
    rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)

    embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
    # Label 1 for rows from the desirable subset, 0 for rows from the undesirable one.
    labels = torch.cat(
        (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
    )

    self.clf = LogisticRegression(class_weight="balanced").fit(
        embeddings.cpu().float().numpy(), labels.cpu().numpy()
    )
|
| 832 |
+
|
| 833 |
+
@property
def match_underlying_distribution(self):
    """Report whether both `embedding_func` and `embedding_tokenizer` are set."""
    if self.embedding_func is None:
        return False
    return self.embedding_tokenizer is not None
|
| 836 |
+
|
| 837 |
+
def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
| 838 |
+
"""
|
| 839 |
+
Calculates the probability if the given prompt embedding is from desirable dataset.
|
| 840 |
+
This function calculates the probability in the process and ensemble across processes.
|
| 841 |
+
"""
|
| 842 |
+
dtype = prompt_embeddings.dtype
|
| 843 |
+
device = prompt_embeddings.device
|
| 844 |
+
rank = self.accelerator.process_index
|
| 845 |
+
|
| 846 |
+
padded_prompt_embeddings = self.accelerator.pad_across_processes(
|
| 847 |
+
prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
|
| 848 |
+
)
|
| 849 |
+
sample_size = padded_prompt_embeddings.shape[0]
|
| 850 |
+
nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
|
| 851 |
+
prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
|
| 852 |
+
|
| 853 |
+
# cannot predict for all empty values
|
| 854 |
+
if prompt_embeddings.shape[0] == 0:
|
| 855 |
+
return torch.tensor([], device=device, dtype=dtype)
|
| 856 |
+
|
| 857 |
+
prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
|
| 858 |
+
prob = torch.as_tensor(prob, dtype=dtype, device=device)
|
| 859 |
+
prob = self.accelerator.reduce(prob, reduction="mean")
|
| 860 |
+
|
| 861 |
+
prob = prob[sample_size * rank : sample_size * (rank + 1)]
|
| 862 |
+
prob = prob[nonzero]
|
| 863 |
+
|
| 864 |
+
return prob
|
| 865 |
+
|
| 866 |
+
def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
|
| 867 |
+
"""
|
| 868 |
+
Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
|
| 869 |
+
and applies self.embedding_func
|
| 870 |
+
"""
|
| 871 |
+
input_ids = torch.where(
|
| 872 |
+
input_ids == self.processing_class.pad_token_id,
|
| 873 |
+
self.embedding_tokenizer.pad_token_id,
|
| 874 |
+
input_ids,
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
with torch.no_grad():
|
| 878 |
+
embeddings = self.embedding_func(
|
| 879 |
+
input_ids=input_ids,
|
| 880 |
+
attention_mask=attention_mask,
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
return embeddings
|
| 884 |
+
|
| 885 |
+
def _get_prompt_embeddings(
|
| 886 |
+
self, batch: dict[str, Union[list, torch.LongTensor]]
|
| 887 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
| 888 |
+
"""Extract embeddings from frozen embedding model"""
|
| 889 |
+
|
| 890 |
+
if not self.match_underlying_distribution:
|
| 891 |
+
return None, None
|
| 892 |
+
|
| 893 |
+
embeddings = self._vectorize_prompt(
|
| 894 |
+
input_ids=batch["embedding_input_ids"],
|
| 895 |
+
attention_mask=batch["embedding_attention_mask"],
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
|
| 899 |
+
rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
|
| 900 |
+
|
| 901 |
+
chosen_embeddings = embeddings[chosen_idx, ...]
|
| 902 |
+
rejected_embeddings = embeddings[rejected_idx, ...]
|
| 903 |
+
|
| 904 |
+
return (chosen_embeddings, rejected_embeddings)
|
| 905 |
+
|
| 906 |
+
def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
    """
    Embed a random sample of prompts from `dataset`.

    At most `sample_size` indices are drawn with `np.random.choice` using its defaults
    (so indices may repeat). The result feeds the density ratio classifier.
    """
    population = len(dataset)
    draw_count = min(population, sample_size)
    picked = np.random.choice(population, size=(draw_count,))
    subset = dataset.select(picked)

    # prepare dataloader
    loader = DataLoader(
        subset,
        batch_size=self.args.per_device_train_batch_size,
        collate_fn=self.data_collator,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        shuffle=False,
    )
    loader = self.accelerator.prepare(loader)

    with torch.no_grad():
        collected = torch.empty(0)
        for padded_batch in tqdm(iterable=loader, desc="Building sample prompt embeddings"):
            batch_embeddings = self._vectorize_prompt(
                input_ids=padded_batch["embedding_input_ids"],
                attention_mask=padded_batch["embedding_attention_mask"],
            )
            batch_embeddings = self.accelerator.gather_for_metrics(batch_embeddings)
            # Accumulate on CPU, one batch at a time.
            collected = torch.cat((collected, batch_embeddings.cpu()))

    return collected
|
| 938 |
+
|
| 939 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
    """Wrap `model` in a DeepSpeed engine in eval mode, using a copy of the plugin config."""
    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
    ds_config = deepcopy(self.accelerator.state.deepspeed_plugin.deepspeed_config)

    if model is not None and hasattr(model, "config"):
        hidden_sizes = getattr(model.config, "hidden_sizes", None)
        if hidden_sizes:
            hidden_size = max(hidden_sizes)
        else:
            hidden_size = getattr(model.config, "hidden_size", None)

        if hidden_size is not None and ds_config["zero_optimization"]["stage"] == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            bucket = hidden_size * hidden_size
            ds_config.update(
                {
                    "zero_optimization.reduce_bucket_size": bucket,
                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                }
            )

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
    zero_section = ds_config["zero_optimization"]
    if zero_section["stage"] != 3:
        zero_section["stage"] = 0

    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()
    return engine
|
| 969 |
+
|
| 970 |
+
def _save_optimizer_and_scheduler(self, output_dir):
    """Save optimizer/scheduler via the parent, then the running moments and classifier params."""
    super()._save_optimizer_and_scheduler(output_dir)

    # When saving optimizer and scheduler to checkpoint, save also the running delta object.
    target_dir = self.args.output_dir if output_dir is None else output_dir
    self.running.save_to_json(os.path.join(target_dir, RUNNING_NAME))

    if not self.match_underlying_distribution:
        return
    torch.save(self.clf.get_params(), os.path.join(target_dir, CLF_NAME))
|
| 980 |
+
|
| 981 |
+
def _load_optimizer_and_scheduler(self, checkpoint):
    """Load optimizer/scheduler via the parent, then restore BCO-specific state.

    Args:
        checkpoint: Checkpoint directory, or `None` (nothing extra is loaded then).

    The running moments are restored from `RUNNING_NAME` when that file exists. When
    `match_underlying_distribution` is true, the classifier parameters are restored
    from `CLF_NAME` when that file exists.
    """
    super()._load_optimizer_and_scheduler(checkpoint)

    if checkpoint is None:
        return

    # when loading optimizer and scheduler from checkpoint, also load the running delta object.
    running_file = os.path.join(checkpoint, RUNNING_NAME)
    if os.path.isfile(running_file):
        self.running = RunningMoments.load_from_json(self.accelerator, running_file)

    if self.match_underlying_distribution:
        clf_file = os.path.join(checkpoint, CLF_NAME)
        # Check the classifier file itself; the previous code tested `running_file` here,
        # which could call `torch.load` on a missing file or skip an existing one.
        if os.path.isfile(clf_file):
            self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
|
| 995 |
+
|
| 996 |
+
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation).

    When `ref_adapter_name` is set, that adapter is activated for the duration of the
    block and `model_adapter_name` (or "default") is re-activated afterwards. Otherwise,
    for a PEFT model, the adapters are disabled for the duration of the block.

    The policy adapter is restored even if the body of the `with` block raises.
    """
    if self.is_peft_model and not self.ref_adapter_name:
        adapter_ctx = self.accelerator.unwrap_model(self.model).disable_adapter()
    else:
        adapter_ctx = nullcontext()

    with adapter_ctx:
        if self.ref_adapter_name:
            self.model.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Runs on both normal exit and exceptions so the model is never left
            # on the reference adapter.
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")
|
| 1009 |
+
|
| 1010 |
+
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Overrides the parent method: on the first call with `precompute_ref_log_probs`
    enabled, reference log probs are computed for the whole training set and stored in a
    new `reference_logps` column before delegating to the parent implementation.
    """
    needs_precompute = self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs

    if needs_precompute:
        # prepare dataloader
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            shuffle=False,
        )
        loader = self.accelerator.prepare(loader)

        # One CPU tensor of gathered reference log probs per batch, in loader order.
        per_batch_logps = [
            self.accelerator.gather_for_metrics(self.compute_reference_log_probs(padded_batch)).cpu()
            for padded_batch in tqdm(iterable=loader, desc="Train dataset reference log probs")
        ]

        self.train_dataset = self.train_dataset.add_column(
            name="reference_logps", column=torch.cat(per_batch_logps).float().numpy()
        )
        self._precomputed_train_ref_log_probs = True

    return super().get_train_dataloader()
|
| 1043 |
+
|
| 1044 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Build the evaluation [`~torch.utils.data.DataLoader`].

    Extends the parent implementation: when `precompute_ref_log_probs` is enabled and the
    reference log probabilities have not been computed yet, the evaluation set is traversed once
    (unshuffled) and the results are stored in a new `reference_logps` column.

    Args:
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
            by the `model.forward()` method are automatically removed. It must implement `__len__`.

    Raises:
        ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
    """
    if eval_dataset is None:
        if self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = self.eval_dataset

    if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
        # Unshuffled loader so the gathered values line up with the dataset rows.
        ref_loader = self.accelerator.prepare(
            DataLoader(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
                shuffle=False,
            )
        )

        gathered_logps = []
        for padded_batch in tqdm(iterable=ref_loader, desc="Eval dataset reference log probs"):
            batch_logps = self.compute_reference_log_probs(padded_batch)
            batch_logps = self.accelerator.gather_for_metrics(batch_logps)
            gathered_logps.append(batch_logps.cpu())

        ref_column = torch.cat(gathered_logps).float().numpy()
        eval_dataset = eval_dataset.add_column(name="reference_logps", column=ref_column)

        # Keep the augmented dataset on the trainer (when it owns one) for subsequent runs.
        if self.eval_dataset is not None:
            self.eval_dataset = eval_dataset
        self._precomputed_eval_ref_log_probs = True

    return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
| 1089 |
+
|
| 1090 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> torch.FloatTensor:
    """Compute reference-model log probabilities for one padded batch of a BCO dataset.

    When `self.ref_model` is `None`, `self.model` is used as the reference inside
    `self.null_ref_context()`; otherwise `self.ref_model` is called directly. The forward
    pass runs under `torch.no_grad()`.

    Args:
        padded_batch: Collated batch. Reads `completion_labels` plus either
            `prompt_input_ids` / `prompt_attention_mask` (and optionally
            `completion_decoder_input_ids`) for encoder-decoder models, or
            `completion_input_ids` / `completion_attention_mask` otherwise.

    Returns:
        The value of `self.get_batch_logps(...)` called with `average_log_prob=False`,
        i.e. the summed completion log probabilities (a tensor, not a dict).
    """
    policy_is_reference = self.ref_model is None
    reference_model = self.model if policy_is_reference else self.ref_model
    # The context is only created here; it is entered inside the no_grad block below.
    reference_context = self.null_ref_context() if policy_is_reference else nullcontext()

    with torch.no_grad():
        with reference_context:
            if self.is_encoder_decoder:
                completion_logits = reference_model(
                    padded_batch["prompt_input_ids"],
                    attention_mask=padded_batch["prompt_attention_mask"],
                    decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                    labels=padded_batch["completion_labels"],
                ).logits
            else:
                completion_logits = reference_model(
                    padded_batch["completion_input_ids"],
                    attention_mask=padded_batch["completion_attention_mask"],
                ).logits

    completion_logps = self.get_batch_logps(
        completion_logits,
        padded_batch["completion_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    return completion_logps
|
| 1132 |
+
|
| 1133 |
+
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of
            label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per (non-masked) token.
            Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: Label value marking positions that are excluded from the result.
        is_encoder_decoder: If False, labels drop their first position and logits their last
            before the lookup; if True, both are used as given.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given
        labels under the given logits.

    Raises:
        ValueError: If the leading dimensions of `logits` differ from the shape of `labels`.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if is_encoder_decoder:
        # Clone so the in-place masking below never touches the caller's tensor.
        labels = labels.clone()
    else:
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

    keep_mask = labels != label_pad_token_id

    # Ignored positions get a dummy index; their contribution is zeroed by keep_mask below.
    labels[~keep_mask] = 0

    token_logps = selective_log_softmax(logits, labels)
    summed_logps = (token_logps * keep_mask).sum(-1)

    if average_log_prob:
        return summed_logps / keep_mask.sum(-1)
    return summed_logps
|
| 1172 |
+
|
| 1173 |
+
def forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run `model` on the completions and split the outputs by desirability label.

    Args:
        model: Model to call on `completion_input_ids` / `completion_attention_mask`.
        batch: Collated batch; also reads `completion_labels` and `label`, and for
            encoder-decoder models `completion_decoder_input_ids` if present.

    Returns:
        `(chosen_logps, rejected_logps, chosen_logits, rejected_logits)`; when
        `self.aux_loss_enabled` is set, `outputs.aux_loss` is appended as a fifth element.

    Raises:
        ValueError: If the number of computed log probabilities differs from `len(batch["label"])`.
    """
    model_kwargs = {}
    if self.is_encoder_decoder:
        model_kwargs["labels"] = batch["completion_labels"]
        model_kwargs["decoder_input_ids"] = batch.get("completion_decoder_input_ids")
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        batch["completion_input_ids"],
        attention_mask=batch["completion_attention_mask"],
        **model_kwargs,
    )
    completion_logits = outputs.logits

    completion_logps = self.get_batch_logps(
        completion_logits,
        batch["completion_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    example_labels = batch["label"]
    if completion_logps.shape[0] != len(example_labels):
        raise ValueError(
            "There is a mismatch between the number of examples in this batch and the number of "
            "examples for which an output sequence was predicted."
        )

    # Identity checks against the bool singletons: any other label value lands in neither group.
    chosen_idx = []
    rejected_idx = []
    for position, example_label in enumerate(example_labels):
        if example_label is True:
            chosen_idx.append(position)
        elif example_label is False:
            rejected_idx.append(position)

    chosen_logps = completion_logps[chosen_idx, ...]
    rejected_logps = completion_logps[rejected_idx, ...]
    chosen_logits = completion_logits[chosen_idx, ...]
    rejected_logits = completion_logits[rejected_idx, ...]

    result = (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
    if self.aux_loss_enabled:
        return result + (outputs.aux_loss,)
    return result
|
| 1221 |
+
|
| 1222 |
+
def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
    """Return clamped odds weights for the rejected samples.

    The probability returned by `self._get_chosen_prob` is turned into the ratio
    `p / (1 - p + 1e-8)` and clamped to
    `[self.args.min_density_ratio, self.args.max_density_ratio]`.
    """
    p_chosen = self._get_chosen_prob(rejected_embeddings)

    # The small epsilon keeps the denominator non-zero when p_chosen reaches 1.
    density_ratio = p_chosen / (1 - p_chosen + 1e-8)

    return density_ratio.clamp(min=self.args.min_density_ratio, max=self.args.max_density_ratio)
|
| 1230 |
+
|
| 1231 |
+
def bco_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    chosen_embeddings: Optional[torch.FloatTensor],
    rejected_embeddings: Optional[torch.FloatTensor],
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the BCO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
        chosen_embeddings: embeddings of desirable prompts (not read by this method)
        rejected_embeddings: embeddings of undesirable prompts; only used when
            `self.match_underlying_distribution` is set

    Returns:
        A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
        The losses tensor contains the BCO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        The delta value contains the moving average of all implicit rewards.
    """
    has_chosen = policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0
    has_rejected = policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0

    def _empty():
        # An empty tensor (rather than nothing) stands in for a missing split,
        # because accelerate.gather will hang on empty lists.
        return torch.Tensor([]).to(self.accelerator.device)

    if has_chosen:
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
    else:
        chosen_rewards = _empty()

    if has_rejected:
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)
    else:
        rejected_rewards = _empty()

    # Feed the batch mean reward to the running statistic, then read back its mean as delta.
    batch_mean_reward = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
    self.running.update(batch_mean_reward)
    delta = self.running.mean

    chosen_losses = -F.logsigmoid(chosen_rewards - delta) if has_chosen else _empty()
    rejected_losses = -F.logsigmoid(-(rejected_rewards - delta)) if has_rejected else _empty()

    if self.match_underlying_distribution:
        chosen_weight = torch.ones_like(chosen_losses)
        rejected_weight = self._get_udm_weight(rejected_embeddings)
        loss_parts = (chosen_weight * chosen_losses, rejected_weight * rejected_losses)
    else:
        loss_parts = (chosen_losses, rejected_losses)

    losses = torch.cat(loss_parts, dim=0)

    return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
|
| 1292 |
+
|
| 1293 |
+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
):
    """Compute the BCO loss and other metrics for the given batch of inputs for train or test.

    Returns:
        `(loss, metrics)` where `loss` is the NaN-ignoring mean of the per-example losses (plus
        the weighted auxiliary loss when `self.aux_loss_enabled` is set) and `metrics` maps names
        to Python numbers: `delta`, and per split (`chosen` / `rejected`) the gathered
        `rewards/..._sum`, `logps/..._sum`, `logits/..._sum` and `count/...` entries, emitted
        only for splits with at least one example across processes.
    """
    device = self.accelerator.device
    gather = self.accelerator.gather_for_metrics
    metrics = {}

    # Move tensors to the accelerator device; non-tensor entries are passed through untouched.
    batch = {
        key: (value.to(device) if isinstance(value, torch.Tensor) else value)
        for key, value in batch.items()
    }

    forward_output = self.forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = forward_output[:4]
    if self.aux_loss_enabled:
        aux_loss = forward_output[4]

    if "reference_logps" in batch:
        # Precomputed reference values are split with the same identity checks as the policy outputs.
        precomputed = batch["reference_logps"]
        example_labels = batch["label"]
        chosen_idx = []
        rejected_idx = []
        for position in range(precomputed.shape[0]):
            example_label = example_labels[position]
            if example_label is True:
                chosen_idx.append(position)
            elif example_label is False:
                rejected_idx.append(position)

        reference_chosen_logps = precomputed[chosen_idx, ...]
        reference_rejected_logps = precomputed[rejected_idx, ...]
    else:
        with torch.no_grad():
            if self.ref_model is None:
                # No separate reference model: reuse the policy inside the reference context.
                with self.null_ref_context():
                    reference_output = self.forward(self.model, batch)[:4]
            else:
                reference_output = self.forward(self.ref_model, batch)[:4]
        reference_chosen_logps, reference_rejected_logps, _, _ = reference_output

    chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)

    losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        chosen_embeddings,
        rejected_embeddings,
    )
    metrics["delta"] = gather(delta).mean().item()

    # Gather both counts before any per-split statistic so every process issues
    # its gather calls in the same order.
    num_chosen = torch.Tensor([len(chosen_rewards)]).to(device)
    num_rejected = torch.Tensor([len(rejected_rewards)]).to(device)
    split_counts = {
        "chosen": gather(num_chosen).sum().item(),
        "rejected": gather(num_rejected).sum().item(),
    }

    split_values = (
        ("chosen", chosen_rewards, policy_chosen_logps, policy_chosen_logits),
        ("rejected", rejected_rewards, policy_rejected_logps, policy_rejected_logits),
    )
    for split, rewards, logps, logits in split_values:
        if split_counts[split] > 0:
            metrics[f"rewards/{split}_sum"] = gather(rewards.nansum()).nansum().item()
            metrics[f"logps/{split}_sum"] = gather(logps.nansum()).nansum().item()
            metrics[f"logits/{split}_sum"] = gather(logits.nansum()).nansum().item()
            metrics[f"count/{split}"] = split_counts[split]

    loss = losses.nanmean()
    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
|
| 1384 |
+
|
| 1385 |
+
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the training loss for `inputs` and record the batch metrics.

    Args:
        model: Model passed through to `get_batch_loss_metrics`.
        inputs: Batch passed through to `get_batch_loss_metrics`.
        return_outputs: If True, return `(loss, metrics)` instead of the loss alone.
        num_items_in_batch: Accepted for signature compatibility; not read here.

    Returns:
        The loss moved to `self.args.device`, optionally paired with the metrics dict.
    """
    if self._peft_has_been_casted_to_bf16:
        loss_context = amp.autocast("cuda")
    else:
        loss_context = nullcontext()

    with loss_context:
        loss, metrics = self.get_batch_loss_metrics(model, inputs)

    # The accumulating loss in the `Trainer` class lives on `args.device`, so move ours there.
    loss = loss.to(self.args.device)

    # Force-log the metrics, on the main process only.
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
|
| 1406 |
+
|
| 1407 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append every value in `metrics` to `self._stored_metrics[train_eval]` under its key."""
    for name, value in metrics.items():
        history = self._stored_metrics[train_eval][name]
        history.append(value)
|
| 1410 |
+
|
| 1411 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    """Return a `SequentialSampler` over the training set, or `None` if it is missing or unsized."""
    dataset = self.train_dataset
    if dataset is not None and has_length(dataset):
        return SequentialSampler(dataset)
    return None
|
| 1415 |
+
|
| 1416 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
    """Generate samples from the model and reference model for the given batch of inputs.

    Args:
        model: Policy model whose `generate` method is called on the prompts.
        batch: Reads `prompt_input_ids` and `prompt_attention_mask`; if it contains
            `reference_output`, that value is used instead of generating with a reference model.

    Returns:
        `(policy_output_decoded, reference_output_decoded)` as produced by
        `self.processing_class.batch_decode` with special tokens skipped.
    """
    # If one uses `generate_during_eval` with peft + bf16, generation has to run inside the
    # torch cuda amp context manager as some hidden states are silently casted to full precision.
    if self._peft_has_been_casted_to_bf16:
        generation_context = amp.autocast("cuda")
    else:
        generation_context = nullcontext()

    pad_token_id = self.processing_class.pad_token_id
    # The policy and the reference are sampled with exactly the same arguments.
    generation_kwargs = {
        "input_ids": batch["prompt_input_ids"],
        "attention_mask": batch["prompt_attention_mask"],
        "max_length": self.max_length,
        "do_sample": True,
        "pad_token_id": pad_token_id,
    }

    with generation_context:
        policy_output = model.generate(**generation_kwargs)

        if "reference_output" in batch:
            reference_output = batch["reference_output"]
        elif self.ref_model is None:
            # No separate reference model: sample from the policy inside the reference context.
            with self.null_ref_context():
                reference_output = self.model.generate(**generation_kwargs)
        else:
            reference_output = self.ref_model.generate(**generation_kwargs)

    policy_output = pad_to_length(policy_output, self.max_length, pad_token_id)
    policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

    reference_output = pad_to_length(reference_output, self.max_length, pad_token_id)
    reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)

    return policy_output_decoded, reference_output_decoded
|
| 1460 |
+
|
| 1461 |
+
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run one evaluation step and record its metrics under the `"eval"` split.

    Args:
        model: Model passed through to `get_batch_loss_metrics`.
        inputs: Batch passed through to `get_batch_loss_metrics`.
        prediction_loss_only: If True, only the loss is returned.
        ignore_keys: Names to leave out of the returned logits. Defaults to
            `model.config.keys_to_ignore_at_inference` when the model has a config, else `[]`.

    Returns:
        `(loss, logits, labels)`. `logits` holds the `logits/chosen_sum` and
        `logits/rejected_sum` metrics that are present for this batch, and `labels` is a zero
        tensor of the same length. Both are `None` when `prediction_loss_only` is True.
    """
    if ignore_keys is None:
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs)

    # force log the metrics
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # `get_batch_loss_metrics` stores the logits under "logits/<split>_sum", as plain Python
    # numbers, and only for splits that have samples. Read those keys and tolerate a missing
    # split instead of indexing the non-existent "logits/chosen" / "logits/rejected".
    logits_dict = {}
    if "logits/chosen_sum" in metrics:
        logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
    if "logits/rejected_sum" in metrics:
        logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]

    kept_logits = [value for key, value in logits_dict.items() if key not in ignore_keys]
    logits = torch.tensor(kept_logits, device=self.accelerator.device)
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
|
| 1495 |
+
|
| 1496 |
+
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When `self.generate_during_eval` is set, one random batch is drawn from the dataset, the
    policy and the reference generate completions for its `label is False` prompts, and the
    result is logged as a "game log" table to wandb and/or comet_ml (if listed in
    `self.args.report_to`) before the parent evaluation loop runs.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Never ask for more indices than the dataset holds: random.sample raises
        # ValueError when the sample is larger than the population.
        num_samples = len(dataloader.dataset)
        sample_size = min(self.args.eval_batch_size, num_samples)
        random_indices = random.sample(range(num_samples), k=sample_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        target_indices = [i for i, label in enumerate(random_batch["label"]) if label is False]

        # Nothing to generate when the sampled batch has no `label is False` example.
        if target_indices:
            target_batch = {
                "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
                "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
                # Explicit list: itemgetter(*indices) returns a bare string for a single
                # index, which would make the zip below iterate over its characters.
                "prompt": [random_batch["prompt"][i] for i in target_indices],
            }
            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

            table = pd.DataFrame(
                columns=["Prompt", "Policy", "Ref Model"],
                data=[
                    # Strip the prompt prefix so only the generated continuation is shown.
                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                    for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
                ],
            )
            if "wandb" in self.args.report_to:
                wandb.log({"game_log": wandb.Table(data=table)})

            if "comet_ml" in self.args.report_to:
                log_table_to_comet_experiment(
                    name="game_log.csv",
                    table=table,
                )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
|
| 1552 |
+
|
| 1553 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    The stored per-batch sums and counts are folded into per-split averages, a reward margin is
    added when both splits are present, every remaining stored metric is averaged, and the
    stored metrics for the split are then cleared.

    Args:
        logs (`dict[str, float]`):
            The values to log.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # logs either has 'loss' or 'eval_loss'
    is_train = "loss" in logs
    train_eval = "train" if is_train else "eval"
    # train metrics should have no prefix, eval should have 'eval_'
    prefix = "" if is_train else "eval_"

    stored = self._stored_metrics[train_eval]

    # Turn the accumulated sums and counts into averages, dropping the raw entries as we go.
    for split in ("chosen", "rejected"):
        count_key = f"count/{split}"
        if count_key not in stored:
            continue
        count_sum = torch.Tensor(stored[count_key]).sum().item()
        for metric in ("rewards", "logps", "logits"):
            sum_key = f"{metric}/{split}_sum"
            total = torch.Tensor(stored[sum_key]).sum().item()
            logs[f"{prefix}{metric}/{split}"] = total / count_sum
            # delete obsolete metric
            del stored[sum_key]
        del stored[count_key]

    # calculate reward margin
    chosen_key = f"{prefix}rewards/chosen"
    rejected_key = f"{prefix}rewards/rejected"
    if chosen_key in logs and rejected_key in logs:
        logs[f"{prefix}rewards/margins"] = logs[chosen_key] - logs[rejected_key]

    # Add averaged stored metrics to logs
    for key, values in stored.items():
        logs[f"{prefix}{key}"] = torch.Tensor(values).mean().item()
    del self._stored_metrics[train_eval]

    accepts_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
    if accepts_start_time:
        return super().log(logs, start_time)
    # transformers<=4.46
    return super().log(logs)
|
| 1591 |
+
|
| 1592 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing happens on
    processes for which `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The passed object is never modified.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when the recorded name is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalise to a fresh list so the "unsloth" tag appended below never
    # leaks into a list owned by the caller.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @article{jung2024binary,
        title = {{Binary Classifier Optimization for Large Language Model Alignment}},
        author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
        year = 2024,
        eprint = {arXiv:2404.04656}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="BCO",
        trainer_citation=citation,
        paper_title="Binary Classifier Optimization for Large Language Model Alignment",
        paper_id="2404.04656",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1647 |
+
class UnslothBCOTrainer(_UnslothBCOTrainer):
|
| 1648 |
+
"""
|
| 1649 |
+
|
| 1650 |
+
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
|
| 1651 |
+
|
| 1652 |
+
Args:
|
| 1653 |
+
model (`transformers.PreTrainedModel`):
|
| 1654 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1655 |
+
ref_model (`PreTrainedModelWrapper`):
|
| 1656 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
| 1657 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
| 1658 |
+
args (`BCOConfig`):
|
| 1659 |
+
The arguments to use for training.
|
| 1660 |
+
train_dataset (`datasets.Dataset`):
|
| 1661 |
+
The dataset to use for training.
|
| 1662 |
+
eval_dataset (`datasets.Dataset`):
|
| 1663 |
+
The dataset to use for evaluation.
|
| 1664 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1665 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1666 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1667 |
+
reuse the fine-tuned model.
|
| 1668 |
+
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
| 1669 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1670 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1671 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1672 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1673 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 1674 |
+
The callbacks to use for training.
|
| 1675 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1676 |
+
The optimizer and scheduler to use for training.
|
| 1677 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1678 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 1679 |
+
peft_config (`dict`, defaults to `None`):
|
| 1680 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1681 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1682 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1683 |
+
a dictionary string to metric values.
|
| 1684 |
+
model_adapter_name (`str`, defaults to `None`):
|
| 1685 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 1686 |
+
ref_adapter_name (`str`, defaults to `None`):
|
| 1687 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 1688 |
+
|
| 1689 |
+
"""
|
| 1690 |
+
def __init__(
    self,
    model = None,
    ref_model = None,
    args = None,
    train_dataset = None,
    eval_dataset = None,
    processing_class = None,
    data_collator = None,
    model_init = None,
    callbacks = None,
    preprocess_logits_for_metrics = None,
    peft_config = None,
    compute_metrics = None,
    model_adapter_name = None,
    ref_adapter_name = None,
    embedding_func = None,
    embedding_tokenizer = None,
    **kwargs
):
    """
    Normalise the training arguments, then delegate to the parent trainer's `__init__`.

    Before calling `super().__init__`, this method:
      * creates a default `UnslothBCOConfig` when `args` is None;
      * reconciles `args.fp16` / `args.bf16` (and the `*_full_eval` flags) with the
        model dtype and the `UNSLOTH_FORCE_FLOAT32` / `UNSLOTH_MIXED_PRECISION`
        environment variables, writing `ACCELERATE_MIXED_PRECISION` accordingly;
      * adjusts evaluation batch size / accumulation defaults;
      * sets `UNSLOTH_RETURN_LOGITS=1` when metric callbacks are supplied;
      * copies `model.max_seq_length` into `args.max_seq_length` when the latter is unset;
      * forces right padding on `processing_class` and swaps the data collator when it
        does not match the dataset columns.

    All named parameters and `**kwargs` are forwarded unchanged to the parent
    constructor, except `args` (mutated in place) and `data_collator` (possibly replaced).

    Raises:
        TypeError: if the model dtype is float16 while `args.bf16` is set, or the model
            dtype is not float16 while `args.fp16` is set (unless float32 is forced).
    """
    if args is None: args = UnslothBCOConfig()
    # Remember what the caller explicitly asked for before `args` is mutated below.
    use_bf16 = getattr(args, 'bf16', False)
    use_fp16 = getattr(args, 'fp16', False)
    force_float32 = False
    if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
        print('Unsloth: Switching to float32 training since model cannot work with float16')
        force_float32 = True
    mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
    # NOTE(review): `model` defaults to None and is only None-checked further down
    # (before `for_training`); with model=None this line raises AttributeError - confirm
    # whether a None model is ever expected here.
    dtype = getattr(model.config, 'torch_dtype', None)
    # Fall back to the dtype of the input embedding module when the config carries none.
    if dtype is None: dtype = model.get_input_embeddings().dtype
    from unsloth_zoo.utils import _get_dtype
    dtype = _get_dtype(dtype)
    float16 = dtype == torch.float16
    # Reject a requested mixed-precision mode that contradicts the model dtype.
    if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
    if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
    if force_float32:
        # Forced float32: disable both half-precision modes.
        args.fp16 = False
        args.bf16 = False
        os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
    elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
        # Caller chose neither mode: derive it from the model dtype
        # (fp16 for float16 models, bf16 for everything else).
        args.fp16 = float16
        args.bf16 = not float16
        os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
    # NOTE(review): this reads an `eval_dataset` attribute from `args`, not the
    # `eval_dataset` parameter of this method - looks like the parameter was intended;
    # confirm before relying on evaluation being auto-enabled.
    if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
        args.eval_strategy = 'steps'
        if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
    ga_steps = getattr(args, 'gradient_accumulation_steps', None)
    if ga_steps is not None and ga_steps > 1:
        from transformers import __version__ as transformers_version
        # Only warns (does not abort) when gradient accumulation is used with transformers <= 4.45.2.
        if Version(transformers_version) <= Version('4.45.2'):
            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                  '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
    if getattr(args, 'eval_strategy', 'no') != 'no':
        eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
        # An eval batch size of exactly 8 is lowered to the train batch size when that is smaller.
        if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
        # Default eval accumulation to the gradient accumulation step count.
        if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
    fp16_full_eval = getattr(args, 'fp16_full_eval', False)
    bf16_full_eval = getattr(args, 'bf16_full_eval', False)
    # Make the full-eval flags agree with the training precision resolved above.
    if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
    if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
    if force_float32:
        args.bf16_full_eval = False
        args.fp16_full_eval = False
    elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
        args.bf16_full_eval = True
        args.fp16_full_eval = False
    elif not bf16_full_eval and not fp16_full_eval:
        # Neither full-eval flag was requested: mirror the training precision flags.
        args.bf16_full_eval = args.bf16
        args.fp16_full_eval = args.fp16
    _output_logits = False
    # Either metric callback being supplied switches on UNSLOTH_RETURN_LOGITS.
    if locals().get('compute_metrics', None) is not None: _output_logits = True
    if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
    if _output_logits:
        os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
    # This signature has no `max_seq_length` parameter and the local is not yet bound,
    # so the `locals()` test is true here; the branch is decided by whether `args`
    # has a `max_seq_length` attribute.
    if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
        pass
    else:
        model_max_seq_length = getattr(model, 'max_seq_length', None)
        args_max_seq_length = getattr(args, 'max_seq_length', None)
        # Inherit the model's limit only when `args` leaves it unset.
        if args_max_seq_length is None and model_max_seq_length is not None:
            max_seq_length = model.max_seq_length
            if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
    if model is not None and hasattr(model, 'for_training'):
        model.for_training()
    # There is no `tokenizer` parameter in this signature, so this guard is false here
    # and `tokenizer` is never evaluated.
    if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
    # `processing_class` is a parameter, so it is always present in locals().
    if 'processing_class' in locals():
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        # Also covers objects that wrap a tokenizer in a `.tokenizer` attribute.
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
    __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
    from unsloth_zoo.vision_utils import UnslothVisionDataCollator
    if not isinstance(data_collator, UnslothVisionDataCollator):
        # Swap the collator type when it disagrees with the presence of a 'labels' column.
        if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
            data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
        elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
            data_collator = DataCollatorForSeq2Seq(__tokenizer)
    else:
        # Vision collator: override the dataset-related settings on `args` where they exist.
        if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
        if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
        if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
    if not isinstance(data_collator, UnslothVisionDataCollator):
        # The object has no `pad` itself but wraps a `.tokenizer`: rebuild the collator
        # around the inner tokenizer.
        if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
            if isinstance(data_collator, DataCollatorForSeq2Seq):
                data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
            else:
                data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
    other_metrics = []

    from unsloth_zoo.logging_utils import PatchRLStatistics
    # Called with this trainer's name and an empty list of extra metrics.
    PatchRLStatistics('bco_trainer', other_metrics)

    super().__init__(
        model = model,
        ref_model = ref_model,
        args = args,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        processing_class = processing_class,
        data_collator = data_collator,
        model_init = model_init,
        callbacks = callbacks,
        preprocess_logits_for_metrics = preprocess_logits_for_metrics,
        peft_config = peft_config,
        compute_metrics = compute_metrics,
        model_adapter_name = model_adapter_name,
        ref_adapter_name = ref_adapter_name,
        embedding_func = embedding_func,
        embedding_tokenizer = embedding_tokenizer,**kwargs)
    # If the parent constructor left a `neftune_hook_handle` on the instance, remove
    # that hook and drop the attribute.
    if hasattr(self, 'neftune_hook_handle'):
        self.neftune_hook_handle.remove()
        if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
    # When `args.neftune_noise_alpha` is set, store the alpha on the input embedding module.
    if getattr(args, 'neftune_noise_alpha', None) is not None:
        model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
    pass
|
| 1823 |
+
|
| 1824 |
+
pass
|
unsloth_compiled_cache/UnslothCPOTrainer.py
ADDED
|
@@ -0,0 +1,1557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Option set passed as `options = ...` to `torch.compile` by the compiled helpers in this module.
torch_compile_options = {
    'epilogue_fusion': True,
    'max_autotune': False,
    'shape_padding': True,
    'trace.enabled': False,
    'triton.cudagraphs': False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` over the last dimension, evaluated only at `index`.

    Uses the identity log_softmax(x)_i = x_i - logsumexp(x), so only the selected
    entries are produced instead of a full log-softmax tensor. The computation is
    carried out in float32 regardless of the input dtype.

    Args:
        logits: tensor of scores; the last dimension is the one being normalised.
        index: integer tensor of positions along the last dimension of `logits`
            (it is given a trailing axis before `torch.gather`, so it is expected
            to have one dimension fewer than `logits`).

    Returns:
        float32 tensor of per-position log-probabilities with the shape of `index`.
    """
    logits_fp32 = logits.to(torch.float32)
    # Normaliser for every row in one call (no per-row loop).
    log_normalizer = torch.logsumexp(logits_fp32, dim = -1)
    gather_index = index.unsqueeze(-1)
    chosen_logits = torch.gather(logits_fp32, dim = -1, index = gather_index).squeeze(-1)
    return chosen_logits - log_normalizer
|
| 42 |
+
@dataclass
|
| 43 |
+
class UnslothCPOConfig(CPOConfig):
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
Configuration class for the [`CPOTrainer`].
|
| 47 |
+
|
| 48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 50 |
+
command line.
|
| 51 |
+
|
| 52 |
+
Parameters:
|
| 53 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
| 54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 55 |
+
[`~transformers.TrainingArguments`].
|
| 56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 57 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 58 |
+
to use the default data collator.
|
| 59 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 60 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
| 61 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
| 62 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
| 63 |
+
and your model is an encoder-decoder.
|
| 64 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
| 65 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 66 |
+
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
| 67 |
+
the [paper](https://huggingface.co/papers/2310.12036).
|
| 68 |
+
label_smoothing (`float`, *optional*, defaults to `0.0`):
|
| 69 |
+
Label smoothing factor. This argument is required if you want to use the default data collator.
|
| 70 |
+
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
| 71 |
+
Type of loss to use. Possible values are:
|
| 72 |
+
|
| 73 |
+
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
| 74 |
+
- `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
|
| 75 |
+
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
| 76 |
+
- `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
|
| 77 |
+
|
| 78 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 79 |
+
Whether to disable dropout in the model.
|
| 80 |
+
cpo_alpha (`float`, *optional*, defaults to `1.0`):
|
| 81 |
+
Weight of the BC regularizer in CPO training.
|
| 82 |
+
simpo_gamma (`float`, *optional*, defaults to `0.5`):
|
| 83 |
+
Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
|
| 84 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
| 85 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
| 86 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
| 87 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
| 88 |
+
truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
|
| 89 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
| 90 |
+
This argument is required if you want to use the default data collator.
|
| 91 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 92 |
+
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
| 93 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
| 94 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
| 95 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
| 96 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 97 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 98 |
+
string.
|
| 99 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
| 100 |
+
Number of processes to use for processing the dataset.
|
| 101 |
+
|
| 102 |
+
"""
|
| 103 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 104 |
+
default = None,
|
| 105 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 106 |
+
)
|
| 107 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 108 |
+
default = -1,
|
| 109 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 110 |
+
)
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
output_dir = None,
|
| 114 |
+
overwrite_output_dir = None,
|
| 115 |
+
do_train = False,
|
| 116 |
+
do_eval = False,
|
| 117 |
+
do_predict = False,
|
| 118 |
+
eval_strategy = 'no',
|
| 119 |
+
prediction_loss_only = False,
|
| 120 |
+
per_device_train_batch_size = 4,
|
| 121 |
+
per_device_eval_batch_size = 4,
|
| 122 |
+
per_gpu_train_batch_size = None,
|
| 123 |
+
per_gpu_eval_batch_size = None,
|
| 124 |
+
gradient_accumulation_steps = 2,
|
| 125 |
+
eval_accumulation_steps = 2,
|
| 126 |
+
eval_delay = 0,
|
| 127 |
+
torch_empty_cache_steps = 250,
|
| 128 |
+
learning_rate = 5e-05,
|
| 129 |
+
weight_decay = 0.01,
|
| 130 |
+
adam_beta1 = 0.9,
|
| 131 |
+
adam_beta2 = 0.999,
|
| 132 |
+
adam_epsilon = 1e-08,
|
| 133 |
+
max_grad_norm = 1.0,
|
| 134 |
+
num_train_epochs = 3.0,
|
| 135 |
+
max_steps = -1,
|
| 136 |
+
lr_scheduler_type = 'linear',
|
| 137 |
+
warmup_ratio = 0.1,
|
| 138 |
+
warmup_steps = 0,
|
| 139 |
+
log_level = 'passive',
|
| 140 |
+
log_level_replica = 'warning',
|
| 141 |
+
log_on_each_node = True,
|
| 142 |
+
logging_dir = None,
|
| 143 |
+
logging_strategy = 'steps',
|
| 144 |
+
logging_first_step = False,
|
| 145 |
+
logging_steps = 1,
|
| 146 |
+
logging_nan_inf_filter = False,
|
| 147 |
+
save_strategy = 'steps',
|
| 148 |
+
save_steps = 500,
|
| 149 |
+
save_total_limit = None,
|
| 150 |
+
save_safetensors = True,
|
| 151 |
+
save_on_each_node = False,
|
| 152 |
+
save_only_model = False,
|
| 153 |
+
restore_callback_states_from_checkpoint = False,
|
| 154 |
+
no_cuda = False,
|
| 155 |
+
use_cpu = False,
|
| 156 |
+
use_mps_device = False,
|
| 157 |
+
seed = 3407,
|
| 158 |
+
data_seed = 3407,
|
| 159 |
+
jit_mode_eval = False,
|
| 160 |
+
use_ipex = False,
|
| 161 |
+
bf16 = False,
|
| 162 |
+
fp16 = False,
|
| 163 |
+
fp16_opt_level = 'O1',
|
| 164 |
+
half_precision_backend = 'auto',
|
| 165 |
+
bf16_full_eval = False,
|
| 166 |
+
fp16_full_eval = False,
|
| 167 |
+
tf32 = None,
|
| 168 |
+
local_rank = -1,
|
| 169 |
+
ddp_backend = None,
|
| 170 |
+
tpu_num_cores = None,
|
| 171 |
+
tpu_metrics_debug = False,
|
| 172 |
+
debug = '',
|
| 173 |
+
dataloader_drop_last = False,
|
| 174 |
+
eval_steps = None,
|
| 175 |
+
dataloader_num_workers = 0,
|
| 176 |
+
dataloader_prefetch_factor = None,
|
| 177 |
+
past_index = -1,
|
| 178 |
+
run_name = None,
|
| 179 |
+
disable_tqdm = None,
|
| 180 |
+
remove_unused_columns = True,
|
| 181 |
+
label_names = None,
|
| 182 |
+
load_best_model_at_end = False,
|
| 183 |
+
metric_for_best_model = None,
|
| 184 |
+
greater_is_better = None,
|
| 185 |
+
ignore_data_skip = False,
|
| 186 |
+
fsdp = '',
|
| 187 |
+
fsdp_min_num_params = 0,
|
| 188 |
+
fsdp_config = None,
|
| 189 |
+
tp_size = 0,
|
| 190 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
| 191 |
+
accelerator_config = None,
|
| 192 |
+
deepspeed = None,
|
| 193 |
+
label_smoothing_factor = 0.0,
|
| 194 |
+
optim = 'adamw_8bit',
|
| 195 |
+
optim_args = None,
|
| 196 |
+
adafactor = False,
|
| 197 |
+
group_by_length = False,
|
| 198 |
+
length_column_name = 'length',
|
| 199 |
+
report_to = None,
|
| 200 |
+
ddp_find_unused_parameters = None,
|
| 201 |
+
ddp_bucket_cap_mb = None,
|
| 202 |
+
ddp_broadcast_buffers = None,
|
| 203 |
+
dataloader_pin_memory = True,
|
| 204 |
+
dataloader_persistent_workers = False,
|
| 205 |
+
skip_memory_metrics = True,
|
| 206 |
+
use_legacy_prediction_loop = False,
|
| 207 |
+
push_to_hub = False,
|
| 208 |
+
resume_from_checkpoint = None,
|
| 209 |
+
hub_model_id = None,
|
| 210 |
+
hub_strategy = 'every_save',
|
| 211 |
+
hub_token = None,
|
| 212 |
+
hub_private_repo = None,
|
| 213 |
+
hub_always_push = False,
|
| 214 |
+
gradient_checkpointing = False,
|
| 215 |
+
gradient_checkpointing_kwargs = None,
|
| 216 |
+
include_inputs_for_metrics = False,
|
| 217 |
+
eval_do_concat_batches = True,
|
| 218 |
+
fp16_backend = 'auto',
|
| 219 |
+
evaluation_strategy = None,
|
| 220 |
+
push_to_hub_model_id = None,
|
| 221 |
+
push_to_hub_organization = None,
|
| 222 |
+
push_to_hub_token = None,
|
| 223 |
+
mp_parameters = '',
|
| 224 |
+
auto_find_batch_size = False,
|
| 225 |
+
full_determinism = False,
|
| 226 |
+
torchdynamo = None,
|
| 227 |
+
ray_scope = 'last',
|
| 228 |
+
ddp_timeout = 1800,
|
| 229 |
+
torch_compile = False,
|
| 230 |
+
torch_compile_backend = None,
|
| 231 |
+
torch_compile_mode = None,
|
| 232 |
+
dispatch_batches = None,
|
| 233 |
+
split_batches = None,
|
| 234 |
+
include_tokens_per_second = False,
|
| 235 |
+
include_num_input_tokens_seen = False,
|
| 236 |
+
neftune_noise_alpha = None,
|
| 237 |
+
optim_target_modules = None,
|
| 238 |
+
batch_eval_metrics = False,
|
| 239 |
+
eval_on_start = False,
|
| 240 |
+
use_liger_kernel = False,
|
| 241 |
+
eval_use_gather_object = False,
|
| 242 |
+
average_tokens_across_devices = False,
|
| 243 |
+
max_length = 1024,
|
| 244 |
+
max_prompt_length = 512,
|
| 245 |
+
max_completion_length = None,
|
| 246 |
+
beta = 0.1,
|
| 247 |
+
label_smoothing = 0.0,
|
| 248 |
+
loss_type = 'sigmoid',
|
| 249 |
+
disable_dropout = True,
|
| 250 |
+
cpo_alpha = 1.0,
|
| 251 |
+
simpo_gamma = 0.5,
|
| 252 |
+
label_pad_token_id = -100,
|
| 253 |
+
padding_value = None,
|
| 254 |
+
truncation_mode = 'keep_end',
|
| 255 |
+
generate_during_eval = False,
|
| 256 |
+
is_encoder_decoder = None,
|
| 257 |
+
model_init_kwargs = None,
|
| 258 |
+
dataset_num_proc = None,
|
| 259 |
+
vllm_sampling_params = None,
|
| 260 |
+
unsloth_num_chunks = -1,
|
| 261 |
+
**kwargs,
|
| 262 |
+
):
|
| 263 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 264 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 265 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 266 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 267 |
+
save_strategy = 'no'
|
| 268 |
+
if dataset_num_proc is None:
|
| 269 |
+
from multiprocessing import cpu_count
|
| 270 |
+
dataset_num_proc = cpu_count()
|
| 271 |
+
|
| 272 |
+
super().__init__(
|
| 273 |
+
output_dir = output_dir,
|
| 274 |
+
overwrite_output_dir = overwrite_output_dir,
|
| 275 |
+
do_train = do_train,
|
| 276 |
+
do_eval = do_eval,
|
| 277 |
+
do_predict = do_predict,
|
| 278 |
+
eval_strategy = eval_strategy,
|
| 279 |
+
prediction_loss_only = prediction_loss_only,
|
| 280 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 281 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 282 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 283 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 284 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 285 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 286 |
+
eval_delay = eval_delay,
|
| 287 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 288 |
+
learning_rate = learning_rate,
|
| 289 |
+
weight_decay = weight_decay,
|
| 290 |
+
adam_beta1 = adam_beta1,
|
| 291 |
+
adam_beta2 = adam_beta2,
|
| 292 |
+
adam_epsilon = adam_epsilon,
|
| 293 |
+
max_grad_norm = max_grad_norm,
|
| 294 |
+
num_train_epochs = num_train_epochs,
|
| 295 |
+
max_steps = max_steps,
|
| 296 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 297 |
+
warmup_ratio = warmup_ratio,
|
| 298 |
+
warmup_steps = warmup_steps,
|
| 299 |
+
log_level = log_level,
|
| 300 |
+
log_level_replica = log_level_replica,
|
| 301 |
+
log_on_each_node = log_on_each_node,
|
| 302 |
+
logging_dir = logging_dir,
|
| 303 |
+
logging_strategy = logging_strategy,
|
| 304 |
+
logging_first_step = logging_first_step,
|
| 305 |
+
logging_steps = logging_steps,
|
| 306 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 307 |
+
save_strategy = save_strategy,
|
| 308 |
+
save_steps = save_steps,
|
| 309 |
+
save_total_limit = save_total_limit,
|
| 310 |
+
save_safetensors = save_safetensors,
|
| 311 |
+
save_on_each_node = save_on_each_node,
|
| 312 |
+
save_only_model = save_only_model,
|
| 313 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 314 |
+
no_cuda = no_cuda,
|
| 315 |
+
use_cpu = use_cpu,
|
| 316 |
+
use_mps_device = use_mps_device,
|
| 317 |
+
seed = seed,
|
| 318 |
+
data_seed = data_seed,
|
| 319 |
+
jit_mode_eval = jit_mode_eval,
|
| 320 |
+
use_ipex = use_ipex,
|
| 321 |
+
bf16 = bf16,
|
| 322 |
+
fp16 = fp16,
|
| 323 |
+
fp16_opt_level = fp16_opt_level,
|
| 324 |
+
half_precision_backend = half_precision_backend,
|
| 325 |
+
bf16_full_eval = bf16_full_eval,
|
| 326 |
+
fp16_full_eval = fp16_full_eval,
|
| 327 |
+
tf32 = tf32,
|
| 328 |
+
local_rank = local_rank,
|
| 329 |
+
ddp_backend = ddp_backend,
|
| 330 |
+
tpu_num_cores = tpu_num_cores,
|
| 331 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
| 332 |
+
debug = debug,
|
| 333 |
+
dataloader_drop_last = dataloader_drop_last,
|
| 334 |
+
eval_steps = eval_steps,
|
| 335 |
+
dataloader_num_workers = dataloader_num_workers,
|
| 336 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 337 |
+
past_index = past_index,
|
| 338 |
+
run_name = run_name,
|
| 339 |
+
disable_tqdm = disable_tqdm,
|
| 340 |
+
remove_unused_columns = remove_unused_columns,
|
| 341 |
+
label_names = label_names,
|
| 342 |
+
load_best_model_at_end = load_best_model_at_end,
|
| 343 |
+
metric_for_best_model = metric_for_best_model,
|
| 344 |
+
greater_is_better = greater_is_better,
|
| 345 |
+
ignore_data_skip = ignore_data_skip,
|
| 346 |
+
fsdp = fsdp,
|
| 347 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
| 348 |
+
fsdp_config = fsdp_config,
|
| 349 |
+
tp_size = tp_size,
|
| 350 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 351 |
+
accelerator_config = accelerator_config,
|
| 352 |
+
deepspeed = deepspeed,
|
| 353 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 354 |
+
optim = optim,
|
| 355 |
+
optim_args = optim_args,
|
| 356 |
+
adafactor = adafactor,
|
| 357 |
+
group_by_length = group_by_length,
|
| 358 |
+
length_column_name = length_column_name,
|
| 359 |
+
report_to = report_to,
|
| 360 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 361 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 362 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 363 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
| 364 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 365 |
+
skip_memory_metrics = skip_memory_metrics,
|
| 366 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 367 |
+
push_to_hub = push_to_hub,
|
| 368 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
| 369 |
+
hub_model_id = hub_model_id,
|
| 370 |
+
hub_strategy = hub_strategy,
|
| 371 |
+
hub_token = hub_token,
|
| 372 |
+
hub_private_repo = hub_private_repo,
|
| 373 |
+
hub_always_push = hub_always_push,
|
| 374 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 375 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 376 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 377 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 378 |
+
fp16_backend = fp16_backend,
|
| 379 |
+
evaluation_strategy = evaluation_strategy,
|
| 380 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
| 381 |
+
push_to_hub_organization = push_to_hub_organization,
|
| 382 |
+
push_to_hub_token = push_to_hub_token,
|
| 383 |
+
mp_parameters = mp_parameters,
|
| 384 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 385 |
+
full_determinism = full_determinism,
|
| 386 |
+
torchdynamo = torchdynamo,
|
| 387 |
+
ray_scope = ray_scope,
|
| 388 |
+
ddp_timeout = ddp_timeout,
|
| 389 |
+
torch_compile = torch_compile,
|
| 390 |
+
torch_compile_backend = torch_compile_backend,
|
| 391 |
+
torch_compile_mode = torch_compile_mode,
|
| 392 |
+
dispatch_batches = dispatch_batches,
|
| 393 |
+
split_batches = split_batches,
|
| 394 |
+
include_tokens_per_second = include_tokens_per_second,
|
| 395 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 396 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 397 |
+
optim_target_modules = optim_target_modules,
|
| 398 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 399 |
+
eval_on_start = eval_on_start,
|
| 400 |
+
use_liger_kernel = use_liger_kernel,
|
| 401 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 402 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 403 |
+
max_length = max_length,
|
| 404 |
+
max_prompt_length = max_prompt_length,
|
| 405 |
+
max_completion_length = max_completion_length,
|
| 406 |
+
beta = beta,
|
| 407 |
+
label_smoothing = label_smoothing,
|
| 408 |
+
loss_type = loss_type,
|
| 409 |
+
disable_dropout = disable_dropout,
|
| 410 |
+
cpo_alpha = cpo_alpha,
|
| 411 |
+
simpo_gamma = simpo_gamma,
|
| 412 |
+
label_pad_token_id = label_pad_token_id,
|
| 413 |
+
padding_value = padding_value,
|
| 414 |
+
truncation_mode = truncation_mode,
|
| 415 |
+
generate_during_eval = generate_during_eval,
|
| 416 |
+
is_encoder_decoder = is_encoder_decoder,
|
| 417 |
+
model_init_kwargs = model_init_kwargs,
|
| 418 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
| 419 |
+
self.vllm_sampling_params = vllm_sampling_params
|
| 420 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
| 421 |
+
pass
|
| 422 |
+
|
| 423 |
+
class _UnslothCPOTrainer(Trainer):
|
| 424 |
+
r""""""
|
| 425 |
+
|
| 426 |
+
_tag_names = ["trl", "cpo"]
|
| 427 |
+
|
| 428 |
+
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: Optional[CPOConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
):
    """Prepare the model, collator and datasets, then delegate to the parent trainer.

    The steps, in order: resolve `args.model_init_kwargs` (only allowed when `model` is a
    string), optionally load the model from that string, apply k-bit / gradient-checkpointing
    preparation, resolve length limits, build a default data collator, copy the loss
    hyper-parameters from `args` onto `self`, pre-process and tokenize the datasets, and
    finally call the parent `__init__`.

    Args:
        model: An instantiated model, or a string passed to
            `AutoModelForCausalLM.from_pretrained`.
        args: The configuration object. Required in practice: it is read throughout.
        data_collator: Collator to use. When `None`, a `DPODataCollatorWithPadding` is built
            and `args.remove_unused_columns` is forced to `False`.
        train_dataset: Dataset that is mapped through prompt extraction, chat templating and
            `self.tokenize_row`. It is mapped unconditionally, so it must not be `None`.
        eval_dataset: Optional dataset that receives the same mapping as `train_dataset`.
        processing_class: Tokenizer-like object; required.
        model_init: Forwarded unchanged to the parent `__init__`.
        callbacks: Forwarded unchanged to the parent `__init__`.
        optimizers: Forwarded unchanged to the parent `__init__`.
        preprocess_logits_for_metrics: Forwarded unchanged to the parent `__init__`.
        peft_config: When not `None`, triggers the PEFT preparation branch (merge/unload of an
            existing `PeftModel`, k-bit preparation, optional bf16 casting).
        compute_metrics: Forwarded unchanged to the parent `__init__`.

    Raises:
        ValueError: If `args` is `None`; if `model_init_kwargs` is given for an already
            instantiated model; if `torch_dtype` is invalid; if `peft_config` is given without
            PEFT installed; if `generate_during_eval` is set without wandb/comet; if neither a
            model nor `args.is_encoder_decoder` is given; if `processing_class` is `None`; or
            if `args.loss_type` is `"kto_pair"`.
        AttributeError: If the parent class did not create an `accelerator` attribute.
    """
    if args is None:
        # `args` defaults to None in the signature, but every step below reads from it.
        # Fail with a clear message instead of an AttributeError on NoneType.
        raise ValueError("`args` must be a CPOConfig instance, but got None.")

    def _enable_input_require_grads(target):
        # With gradient checkpointing the inputs must carry `requires_grad=True`,
        # otherwise training will either silently fail or completely fail.
        # For backward compatibility with older versions of transformers
        if hasattr(target, "enable_input_require_grads"):
            target.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            target.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if args.model_init_kwargs is None:
        model_init_kwargs = {}
    elif not isinstance(model, str):
        raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
    else:
        model_init_kwargs = args.model_init_kwargs
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            model_init_kwargs["torch_dtype"] = torch_dtype

    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

    # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
    # has been called in order to properly call autocast if needed.
    self._peft_has_been_casted_to_bf16 = False

    if not is_peft_available() and peft_config is not None:
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            # Only forward `gradient_checkpointing_kwargs` when both the config and the
            # installed `prepare_model_for_kbit_training` know about it.
            _support_gc_kwargs = hasattr(args, "gradient_checkpointing_kwargs") and (
                "gradient_checkpointing_kwargs"
                in inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if _support_gc_kwargs:
                prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
        elif getattr(args, "gradient_checkpointing", False):
            _enable_input_require_grads(model)

        # NOTE(review): the model is not wrapped with `peft_config` here (the original code
        # had only a no-op `model = model` at this point). Presumably the caller has already
        # applied the PEFT adapters - TODO confirm.
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
            # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            self._peft_has_been_casted_to_bf16 = True

    elif getattr(args, "gradient_checkpointing", False):
        _enable_input_require_grads(model)

    if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
            " Please install `wandb` or `comet-ml` to resolve."
        )

    if model is not None:
        self.is_encoder_decoder = model.config.is_encoder_decoder
    elif args.is_encoder_decoder is None:
        raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
    else:
        self.is_encoder_decoder = args.is_encoder_decoder

    if self.is_encoder_decoder:
        self.decoder_start_token_id = model.config.decoder_start_token_id
        self.pad_token_id = model.config.pad_token_id

    if processing_class is None:
        raise ValueError("processing_class must be specified to tokenize a CPO dataset.")

    # Resolve the three length limits, warning whenever a fallback value is used.
    if args.max_length is None:
        warnings.warn(
            "`max_length` is not set in the CPOConfig's init"
            " it will default to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_length = 512
    else:
        max_length = args.max_length

    if args.max_prompt_length is None:
        warnings.warn(
            "`max_prompt_length` is not set in the CPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_prompt_length = 128
    else:
        max_prompt_length = args.max_prompt_length

    if args.max_completion_length is None and self.is_encoder_decoder:
        warnings.warn(
            "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_completion_length = 128
    else:
        max_completion_length = args.max_completion_length

    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(
            pad_token_id=processing_class.pad_token_id,
            label_pad_token_id=args.label_pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_dpo_data_collator = True
    else:
        self.use_dpo_data_collator = False

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(model)

    # These attributes are read by `tokenize_row`, which runs during the dataset mapping
    # below, so they must be set before the `.map(...)` calls.
    self.max_length = max_length
    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
    self.max_prompt_length = max_prompt_length
    self.truncation_mode = args.truncation_mode
    self.max_completion_length = max_completion_length
    self.processing_class = processing_class

    if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
        warnings.warn(
            f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
            "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
            UserWarning,
        )
    if args.loss_type == "kto_pair":
        raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")

    self.beta = args.beta
    self.label_smoothing = args.label_smoothing
    self.loss_type = args.loss_type
    self.cpo_alpha = args.cpo_alpha
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        warnings.warn(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
            UserWarning,
        )

    if args.loss_type == "simpo":
        self.simpo_gamma = args.simpo_gamma

    self._stored_metrics = defaultdict(lambda: defaultdict(list))

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
    # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
    # of the input, floating-point operations will not be computed." To suppress this warning, we set the
    # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
    # that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        # Extract the prompt if needed, and apply the chat template if needed
        train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
        train_dataset = train_dataset.map(
            maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
        )
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
            eval_dataset = eval_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
            )

        # tokenize the dataset
        train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )
|
| 690 |
+
|
| 691 |
+
def build_tokenized_answer(self, prompt, answer):
    """
    Tokenize `prompt + answer` in one pass and split the result into prompt and answer parts.

    Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
    It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
    Reference:
        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Args:
        prompt: Prompt text.
        answer: Answer text; it is concatenated directly after `prompt` before tokenizing.

    Returns:
        A dict with the keys `prompt_input_ids`, `prompt_attention_mask`, `input_ids` and
        `attention_mask`. The last two hold the answer part. All four are slices of the
        joint `prompt + answer` tokenization.

    Raises:
        ValueError: If the prompt/answer split does not add up to the joint tokenization
            length, or if the prompt ids and the prompt attention mask differ in length.
    """
    full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    full_input_ids = full_tokenized["input_ids"]
    full_attention_mask = full_tokenized["attention_mask"]

    # Sanity check on `enc(a) + enc(a + b)[len(enc(a)):]`: the two pieces must add up to the
    # joint tokenization. This fails only when the prompt alone yields more tokens than
    # prompt + answer.
    answer_length = len(full_input_ids[len(prompt_input_ids) :])
    if len(full_input_ids) != len(prompt_input_ids) + answer_length:
        raise ValueError("Prompt input ids and answer input ids should have the same length.")

    # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This could result
    # on the last token from the prompt being different when tokenized on its own
    # vs when done as prompt+answer.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If tokenized prompt is different than both prompt+answer, then it means the
    # last token has changed due to merging, so that token is handed to the answer.
    if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
    prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=full_input_ids[response_token_ids_start_idx:],
        attention_mask=full_attention_mask[response_token_ids_start_idx:],
    )
|
| 740 |
+
|
| 741 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
    """Tokenize a single row from a CPO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
    in case the prompt + chosen or prompt + rejected responses is/are too long. First
    we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

    We also create the labels for the chosen/rejected responses, which are of length equal to
    the sum of the length of the prompt and the chosen/rejected response, with
    label_pad_token_id for the prompt tokens.

    Args:
        feature: Mapping with the keys `"prompt"`, `"chosen"` and `"rejected"`.
        model: Only used for encoder-decoder models; when it provides
            `prepare_decoder_input_ids_from_labels`, decoder input ids are added to the output.

    Returns:
        For decoder-only models: `chosen_*` and `rejected_*` entries (`input_ids`,
        `attention_mask`, `labels`) plus the `prompt_*` entries; `token_type_ids` are dropped.
        For encoder-decoder models: `chosen_labels`, `rejected_labels`, `prompt_input_ids`,
        `prompt_attention_mask` and, if `model` allows it, `*_decoder_input_ids`.

    Raises:
        ValueError: In the decoder-only path, if prompt/chosen/rejected is not a `str`, if the
            two prompt tokenizations differ by more than the last token, or if
            `self.truncation_mode` is unknown when truncation is needed.
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["chosen"]
    rejected = feature["rejected"]

    if not self.is_encoder_decoder:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens
        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

        prompt_tokens = {k: v[:prompt_len_input_ids] for k, v in prompt_tokens.items()}

        # Make sure prompts only have one different token at most
        # and length only differs by 1 at most
        num_diff_tokens = sum(
            a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
        )
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt. Avoid adding if it's already there
        prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
            self.processing_class.bos_token_id,
            prompt_len_input_ids,
            prompt_tokens,
            chosen_prompt_len_input_ids,
            chosen_tokens,
            rejected_prompt_len_input_ids,
            rejected_tokens,
        )

        # add EOS token to end of answer. Avoid adding if it's already there
        chosen_tokens, rejected_tokens = add_eos_token_if_needed(
            self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
        )

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                if self.truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    # NOTE(review): with `max_prompt_length == 0` the slice `[-0:]` keeps the
                    # whole prompt instead of nothing.
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

        # Create labels
        chosen_sequence_tokens = {
            k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        rejected_sequence_tokens = {
            k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        # Labels are a copy of the input ids with the prompt positions replaced by
        # `label_pad_token_id`.
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(chosen_tokens["prompt_input_ids"])
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(rejected_tokens["prompt_input_ids"])

        # Flatten into the output row. The prompt keys already carry their `prompt_` prefix,
        # hence the empty prefix for them.
        for k, toks in {
            "chosen_": chosen_sequence_tokens,
            "rejected_": rejected_sequence_tokens,
            "": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}{type_key}"] = tokens

    else:
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["rejected_labels"])
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["chosen_labels"])
            )

    return batch
|
| 886 |
+
|
| 887 |
+
@staticmethod
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> dict[str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of a batch along the batch dimension.

    Every tensor entry whose key starts with ``chosen`` is padded to a common
    length and stored under the same key with ``chosen`` replaced by
    ``concatenated``. The matching ``rejected`` entry is padded the same way and
    appended below it (``dim=0``), then moved to ``device``.

    Args:
        batch: A batch of data. Must contain 'chosen_input_ids' and 'rejected_input_ids'
            (or 'chosen_labels' and 'rejected_labels' for encoder-decoder models).
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: Fill value used for label tensors (and for every tensor
            when ``is_encoder_decoder`` is true).
        padding_value: Fill value used for ``*_input_ids`` tensors.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary holding the stacked tensors under ``concatenated_*`` keys.
    """
    # The common length is taken from the labels for encoder-decoder models and
    # from the input ids otherwise.
    if is_encoder_decoder:
        target_len = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
    else:
        target_len = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    merged = {}

    # Pass 1: padded "chosen" tensors seed the output dictionary.
    for name, value in batch.items():
        if not (name.startswith("chosen") and isinstance(value, torch.Tensor)):
            continue
        if "labels" in name or is_encoder_decoder:
            fill = label_pad_token_id
        elif name.endswith("_input_ids"):
            fill = padding_value
        elif name.endswith("_attention_mask"):
            fill = 0
        merged[name.replace("chosen", "concatenated")] = pad_to_length(value, target_len, pad_value=fill)

    # Pass 2: padded "rejected" tensors are appended below their "chosen" partner.
    for name, value in batch.items():
        if not (name.startswith("rejected") and isinstance(value, torch.Tensor)):
            continue
        if "labels" in name or is_encoder_decoder:
            fill = label_pad_token_id
        elif name.endswith("_input_ids"):
            fill = padding_value
        elif name.endswith("_attention_mask"):
            fill = 0
        target_key = name.replace("rejected", "concatenated")
        upper_half = merged[target_key]
        lower_half = pad_to_length(value, target_len, pad_value=fill)
        merged[target_key] = torch.cat((upper_half, lower_half), dim=0).to(device=device)

    if is_encoder_decoder:
        # The encoder sees the same prompt for both halves of the stacked batch.
        merged["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        merged["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1).to(device=device)

    return merged
|
| 948 |
+
|
| 949 |
+
def cpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the per-example CPO loss from the policy log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple ``(losses, chosen_rewards, rejected_rewards)``. ``losses`` holds the
        loss of each example; the reward tensors are the detached log probabilities
        scaled by ``self.beta``.

    Raises:
        ValueError: If ``self.loss_type`` is not one of 'sigmoid', 'hinge', 'ipo', 'simpo'.
    """
    device = self.accelerator.device
    margin = (policy_chosen_logps - policy_rejected_logps).to(device)
    kind = self.loss_type

    # SimPO shifts the margin by gamma / beta and then shares the sigmoid form.
    if kind == "simpo":
        margin = margin - self.simpo_gamma / self.beta

    if kind in ("simpo", "sigmoid"):
        # beta acts as a temperature; label_smoothing blends in the loss of the
        # flipped preference. With label_smoothing -> 0 this reduces to
        # Equation 3 of the CPO paper.
        scaled = self.beta * margin
        losses = (
            -F.logsigmoid(scaled) * (1 - self.label_smoothing)
            - F.logsigmoid(-scaled) * self.label_smoothing
        )
    elif kind == "hinge":
        losses = torch.relu(1 - self.beta * margin)
    elif kind == "ipo":
        # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
        losses = (margin - 1 / (2 * self.beta)) ** 2
    else:
        raise ValueError(
            f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
        )

    chosen_rewards = self.beta * policy_chosen_logps.to(device).detach()
    rejected_rewards = self.beta * policy_rejected_logps.to(device).detach()

    return losses, chosen_rewards, rejected_rewards
|
| 999 |
+
|
| 1000 |
+
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of
            label_pad_token_id are ignored. Shape: (batch_size, sequence_length). The tensor is
            never modified.
        average_log_prob: If True, return the average log probability per (non-masked) token.
            Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: The label pad token id.
        is_encoder_decoder: Whether the model is an encoder-decoder model.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the
        given labels under the given logits.

    Raises:
        ValueError: If the leading dimensions of ``logits`` do not match ``labels``.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if not is_encoder_decoder:
        # Decoder-only: position t predicts token t + 1.
        labels = labels[:, 1:]
        logits = logits[:, :-1, :]

    loss_mask = labels != label_pad_token_id

    # Replace padded positions with a dummy token id so they can be gathered; their
    # contribution is zeroed by `loss_mask` below. `masked_fill` is out-of-place, so
    # the caller's tensor is left untouched on the encoder-decoder path as well.
    labels = labels.masked_fill(~loss_mask, 0)

    per_token_logps = selective_log_softmax(logits, labels)
    masked_logps = per_token_logps * loss_mask

    if average_log_prob:
        return masked_logps.sum(-1) / loss_mask.sum(-1)
    return masked_logps.sum(-1)
|
| 1037 |
+
|
| 1038 |
+
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, ...]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Args:
        model: The model to run.
        batch: Batch containing the chosen/rejected tensors expected by `concatenated_inputs`.

    Returns:
        ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)``, with
        ``outputs.aux_loss`` appended as a sixth element when ``self.aux_loss_enabled`` is set.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    # The first `len_chosen` rows of the stacked batch are the chosen examples.
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = {}
    if self.is_encoder_decoder:
        model_kwargs["decoder_input_ids"] = self._shift_right(concatenated_batch["concatenated_labels"])
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    def cross_entropy_loss(logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Padded label positions carry `self.label_pad_token_id`, so that is the
        # value the loss has to skip (the default ignore_index only covers -100).
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        # Flatten the tokens
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    labels = concatenated_batch["concatenated_labels"].clone()

    # The NLL term is only computed on the chosen half, and only when it is weighted in.
    if self.cpo_alpha == 0:
        nll_loss = torch.tensor(0.0).to(self.accelerator.device)
    else:
        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=self.loss_type in ["ipo", "simpo"],
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
| 1112 |
+
|
| 1113 |
+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the CPO loss and other metrics for the given batch of inputs for train or test.

    Args:
        model: The model passed through to `concatenated_forward`.
        batch: The batch passed through to `concatenated_forward`.
        train_eval: When "eval", every metric key is prefixed with ``eval_``.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the mean preference loss plus the
        ``cpo_alpha``-weighted NLL term (plus the weighted auxiliary loss when
        enabled) and ``metrics`` maps metric names to Python floats.
    """
    forward_output = self.concatenated_forward(model, batch)
    chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss = forward_output[:5]
    if self.aux_loss_enabled:
        aux_loss = forward_output[5]

    losses, chosen_rewards, rejected_rewards = self.cpo_loss(chosen_logps, rejected_logps)

    loss = losses.mean() + self.cpo_alpha * nll_loss
    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    if train_eval == "eval":
        prefix = "eval_"
    else:
        prefix = ""

    # (metric name, tensor, detach-before-averaging). The order fixes both the key
    # order of `metrics` and the order of the gather calls.
    reported = (
        ("rewards/chosen", chosen_rewards, False),
        ("rewards/rejected", rejected_rewards, False),
        ("rewards/accuracies", reward_accuracies, False),
        ("rewards/margins", chosen_rewards - rejected_rewards, False),
        ("logps/rejected", rejected_logps, True),
        ("logps/chosen", chosen_logps, True),
        ("logits/rejected", rejected_logits, True),
        ("logits/chosen", chosen_logits, True),
        ("nll_loss", nll_loss, True),
    )

    metrics = {}
    for name, tensor, needs_detach in reported:
        gathered = self.accelerator.gather_for_metrics(tensor)
        if needs_detach:
            gathered = gathered.detach()
        metrics[f"{prefix}{name}"] = gathered.mean().item()

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
|
| 1166 |
+
|
| 1167 |
+
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the training loss for one batch and record its metrics.

    Args:
        model: The model being trained.
        inputs: The batch handed to `get_batch_loss_metrics`.
        return_outputs: When true, return ``(loss, metrics)`` instead of just the loss.
        num_items_in_batch: Accepted but not used by this method.

    Returns:
        The loss, or ``(loss, metrics)`` when ``return_outputs`` is true.
    """
    # Run under CUDA autocast only when the PEFT model was cast to bf16.
    if self._peft_has_been_casted_to_bf16:
        loss_context = amp.autocast("cuda")
    else:
        loss_context = nullcontext()

    with loss_context:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # force log the metrics
    self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
|
| 1185 |
+
|
| 1186 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> list[str]:
    """Sample completions from the model for the prompts in the given batch.

    Args:
        model: Model exposing ``generate``.
        batch: Batch providing ``prompt_input_ids`` and ``prompt_attention_mask``.

    Returns:
        The decoded generations as returned by ``self.processing_class.batch_decode``
        (special tokens skipped).
    """
    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch cuda amp context manager as some hidden states are silently casted to full precision.
    generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    pad_token_id = self.processing_class.pad_token_id

    with generate_context_manager:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # Pad every generation out to `max_length` before decoding.
    policy_output = pad_to_length(policy_output, self.max_length, pad_token_id)
    return self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 1206 |
+
|
| 1207 |
+
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run one evaluation step and record its metrics.

    Args:
        model: The model to evaluate.
        inputs: The batch handed to `get_batch_loss_metrics`.
        prediction_loss_only: When true, only the loss is returned.
        ignore_keys: Metric keys to leave out of the returned logits. Defaults to
            ``model.config.keys_to_ignore_at_inference`` when the model has a config,
            otherwise to an empty list.

    Returns:
        ``(loss, None, None)`` when ``prediction_loss_only`` is true, otherwise
        ``(loss, logits, labels)`` where ``logits`` holds the batch-mean chosen and
        rejected logits (one entry per kept key) and ``labels`` is a zero tensor of
        the same length.
    """
    if ignore_keys is None:
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # force log the metrics
    self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model
    logits_dict = {
        "eval_logits/chosen": metrics["eval_logits/chosen"],
        "eval_logits/rejected": metrics["eval_logits/rejected"],
    }
    # The metric values are plain Python floats (already reduced with `.item()`),
    # so they are packed into a tensor directly rather than treated as tensors.
    kept_values = [value for key, value in logits_dict.items() if key not in ignore_keys]
    logits = torch.tensor(kept_values, device=self.accelerator.device)
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
|
| 1241 |
+
|
| 1242 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append each metric value to the per-key history kept under ``train_eval``."""
    for name in metrics:
        self._stored_metrics[train_eval][name].append(metrics[name])
|
| 1245 |
+
|
| 1246 |
+
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When ``self.generate_during_eval`` is set, one random batch is sampled from the
    evaluation dataset, completions are generated for it, and the prompt/completion
    table is logged to wandb and/or comet_ml (depending on ``self.args.report_to``)
    before the regular evaluation loop runs.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples.
        # The sample size is capped at the dataset size: `random.sample` raises
        # ValueError when asked for more items than the population holds.
        num_samples = len(dataloader.dataset)
        sample_size = min(self.args.eval_batch_size, num_samples)
        random_indices = random.sample(range(num_samples), k=sample_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded = self.generate_from_model(self.model, random_batch)

        # Strip the prompt prefix (by character count) from each decoded generation.
        table = pd.DataFrame(
            columns=["Prompt", "Policy"],
            data=[
                [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
|
| 1295 |
+
|
| 1296 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    The metrics accumulated by `store_metrics` for the matching split are averaged,
    merged into `logs`, and then cleared.

    Args:
        logs (`dict[str, float]`):
            The values to log.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # logs either has 'loss' or 'eval_loss'
    split = "eval"
    if "loss" in logs:
        split = "train"

    # Add averaged stored metrics to logs
    for name, history in self._stored_metrics[split].items():
        logs[name] = torch.tensor(history).mean().item()
    del self._stored_metrics[split]

    # `start_time` is only forwarded for transformers >= 4.47.0.dev0.
    accepts_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
    if not accepts_start_time:  # transformers<=4.46
        return super().log(logs)
    return super().log(logs, start_time)
|
| 1317 |
+
|
| 1318 |
+
def _shift_right(self, input_ids):
    """Build decoder inputs by shifting ``input_ids`` one position to the right.

    The first position is filled with ``self.decoder_start_token_id`` and any
    ``-100`` left in the shifted result is replaced by ``self.pad_token_id``.

    Raises:
        ValueError: If ``self.decoder_start_token_id`` or ``self.pad_token_id`` is None.
    """
    start_id = self.decoder_start_token_id
    if start_id is None:
        raise ValueError(
            "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
        )

    # shift inputs to the right
    if not is_torch_fx_proxy(input_ids):
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_id
    else:
        # Item assignment is not supported natively for proxies.
        leading = torch.full(input_ids.shape[:-1] + (1,), start_id)
        shifted = torch.cat([leading, input_ids[..., :-1]], dim=-1)

    pad_id = self.pad_token_id
    if pad_id is None:
        raise ValueError("model.config.pad_token_id has to be defined.")

    # replace possible -100 values in labels by `pad_token_id`
    shifted.masked_fill_(shifted == -100, pad_id)

    return shifted
|
| 1340 |
+
|
| 1341 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to ``README.md`` inside ``self.args.output_dir``. Nothing is
    done on processes other than the world-zero process.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The passed object is not modified.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalise `tags` to a fresh list so the append below never leaks into the
    # caller's list.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
        @inproceedings{xu2024contrastive,
            title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
            author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
            year = 2024,
            booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
            publisher = {OpenReview.net},
            url = {https://openreview.net/forum?id=51iwkioZpn}
        }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="CPO",
        trainer_citation=citation,
        paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
        paper_id="2401.08417",
    )
    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1397 |
+
class UnslothCPOTrainer(_UnslothCPOTrainer):
|
| 1398 |
+
"""
|
| 1399 |
+
|
| 1400 |
+
Initialize CPOTrainer.
|
| 1401 |
+
|
| 1402 |
+
Args:
|
| 1403 |
+
model (`transformers.PreTrainedModel`):
|
| 1404 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1405 |
+
args (`CPOConfig`):
|
| 1406 |
+
The CPO config arguments to use for training.
|
| 1407 |
+
data_collator (`transformers.DataCollator`):
|
| 1408 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1409 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1410 |
+
train_dataset (`datasets.Dataset`):
|
| 1411 |
+
The dataset to use for training.
|
| 1412 |
+
eval_dataset (`datasets.Dataset`):
|
| 1413 |
+
The dataset to use for evaluation.
|
| 1414 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1415 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1416 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1417 |
+
reuse the fine-tuned model.
|
| 1418 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1419 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1420 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 1421 |
+
The callbacks to use for training.
|
| 1422 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1423 |
+
The optimizer and scheduler to use for training.
|
| 1424 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1425 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 1426 |
+
peft_config (`dict`, defaults to `None`):
|
| 1427 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1428 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1429 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1430 |
+
a dictionary string to metric values.
|
| 1431 |
+
|
| 1432 |
+
"""
|
| 1433 |
+
def __init__(
|
| 1434 |
+
self,
|
| 1435 |
+
model = None,
|
| 1436 |
+
args = None,
|
| 1437 |
+
data_collator = None,
|
| 1438 |
+
train_dataset = None,
|
| 1439 |
+
eval_dataset = None,
|
| 1440 |
+
processing_class = None,
|
| 1441 |
+
model_init = None,
|
| 1442 |
+
callbacks = None,
|
| 1443 |
+
preprocess_logits_for_metrics = None,
|
| 1444 |
+
peft_config = None,
|
| 1445 |
+
compute_metrics = None,
|
| 1446 |
+
**kwargs
|
| 1447 |
+
):
|
| 1448 |
+
if args is None: args = UnslothCPOConfig()
|
| 1449 |
+
use_bf16 = getattr(args, 'bf16', False)
|
| 1450 |
+
use_fp16 = getattr(args, 'fp16', False)
|
| 1451 |
+
force_float32 = False
|
| 1452 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1453 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1454 |
+
force_float32 = True
|
| 1455 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1456 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1457 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1458 |
+
from unsloth_zoo.utils import _get_dtype
|
| 1459 |
+
dtype = _get_dtype(dtype)
|
| 1460 |
+
float16 = dtype == torch.float16
|
| 1461 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1462 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1463 |
+
if force_float32:
|
| 1464 |
+
args.fp16 = False
|
| 1465 |
+
args.bf16 = False
|
| 1466 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1467 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1468 |
+
args.fp16 = float16
|
| 1469 |
+
args.bf16 = not float16
|
| 1470 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1471 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1472 |
+
args.eval_strategy = 'steps'
|
| 1473 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1474 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1475 |
+
if ga_steps is not None and ga_steps > 1:
|
| 1476 |
+
from transformers import __version__ as transformers_version
|
| 1477 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
| 1478 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1479 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1480 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1481 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1482 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1483 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1484 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1485 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1486 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1487 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1488 |
+
if force_float32:
|
| 1489 |
+
args.bf16_full_eval = False
|
| 1490 |
+
args.fp16_full_eval = False
|
| 1491 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1492 |
+
args.bf16_full_eval = True
|
| 1493 |
+
args.fp16_full_eval = False
|
| 1494 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
| 1495 |
+
args.bf16_full_eval = args.bf16
|
| 1496 |
+
args.fp16_full_eval = args.fp16
|
| 1497 |
+
_output_logits = False
|
| 1498 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1499 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1500 |
+
if _output_logits:
|
| 1501 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1502 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1503 |
+
pass
|
| 1504 |
+
else:
|
| 1505 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1506 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1507 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1508 |
+
max_seq_length = model.max_seq_length
|
| 1509 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1510 |
+
if model is not None and hasattr(model, 'for_training'):
|
| 1511 |
+
model.for_training()
|
| 1512 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1513 |
+
if 'processing_class' in locals():
|
| 1514 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1515 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1516 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1517 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1518 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1519 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1520 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
| 1521 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1522 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1523 |
+
else:
|
| 1524 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1525 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1526 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1527 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1528 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1529 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1530 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1531 |
+
else:
|
| 1532 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
| 1533 |
+
other_metrics = []
|
| 1534 |
+
|
| 1535 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1536 |
+
PatchRLStatistics('cpo_trainer', other_metrics)
|
| 1537 |
+
|
| 1538 |
+
super().__init__(
|
| 1539 |
+
model = model,
|
| 1540 |
+
args = args,
|
| 1541 |
+
data_collator = data_collator,
|
| 1542 |
+
train_dataset = train_dataset,
|
| 1543 |
+
eval_dataset = eval_dataset,
|
| 1544 |
+
processing_class = processing_class,
|
| 1545 |
+
model_init = model_init,
|
| 1546 |
+
callbacks = callbacks,
|
| 1547 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1548 |
+
peft_config = peft_config,
|
| 1549 |
+
compute_metrics = compute_metrics,**kwargs)
|
| 1550 |
+
if hasattr(self, 'neftune_hook_handle'):
|
| 1551 |
+
self.neftune_hook_handle.remove()
|
| 1552 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1553 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1554 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1555 |
+
pass
|
| 1556 |
+
|
| 1557 |
+
pass
|
unsloth_compiled_cache/UnslothDDPOTrainer.py
ADDED
|
@@ -0,0 +1,872 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options handed to ``torch.compile(..., options=...)`` by the compiled helper defined below.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` evaluated only at the positions given by `index`.

    Args:
        logits: tensor whose last dimension is normalised over; it is upcast to float32 first.
        index: integer tensor with one entry per row of `logits`, selecting along the last dimension.

    Returns:
        Tensor shaped like `index` holding `log_softmax(logits)[..., index]`.
    """
    logits_fp32 = logits.to(torch.float32)
    gather_index = index.unsqueeze(-1)
    picked_logits = torch.gather(logits_fp32, dim = -1, index = gather_index).squeeze(-1)
    # log_softmax(x_i) = x_i - logsumexp(x), so the full softmax is never materialised.
    normaliser = torch.logsumexp(logits_fp32, dim = -1)
    return picked_logits - normaliser
|
| 42 |
+
@dataclass
class UnslothDDPOConfig(DDPOConfig):
    """
    Configuration class for the [`DDPOTrainer`], extended with two Unsloth-specific fields
    (`vllm_sampling_params` and `unsloth_num_chunks`).

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    The defaults listed below are the ones declared by this class's `__init__`; several differ from the
    values quoted in the upstream `DDPOConfig` documentation.

    Parameters:
        exp_name (`str`, *optional*, defaults to `'main'`):
            Name of this experiment.
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `3407`):
            Random seed.
        log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
            Log with either 'wandb' or 'tensorboard', check
            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
        tracker_kwargs (`Dict`, *optional*):
            Keyword arguments for the tracker (e.g. wandb_project). Not an explicit parameter of this
            `__init__`; it reaches `DDPOConfig` only through `**kwargs`.
        accelerator_kwargs (`Dict`, *optional*):
            Keyword arguments for the accelerator. Not an explicit parameter of this `__init__`; it reaches
            `DDPOConfig` only through `**kwargs`.
        project_kwargs (`Dict`, *optional*):
            Keyword arguments for the accelerator project config (e.g. `logging_dir`). Not an explicit
            parameter of this `__init__`; it reaches `DDPOConfig` only through `**kwargs`.
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        sample_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for sampling.
        sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
            Number of batches to sample per epoch.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Use 8bit Adam optimizer from bitsandbytes.
        train_learning_rate (`float`, *optional*, defaults to `5e-05`):
            Learning rate.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Adam beta1.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Adam beta2.
        train_adam_weight_decay (`float`, *optional*, defaults to `0.01`):
            Adam weight decay.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
            Adam epsilon.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `2`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        train_num_inner_epochs (`int`, *optional*, defaults to `1`):
            Number of inner epochs per outer epoch.
        train_cfg (`bool`, *optional*, defaults to `True`):
            Whether to use classifier-free guidance during training.
        train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
            Clip advantages to the range.
        train_clip_range (`float`, *optional*, defaults to `1e-4`):
            PPO clip range.
        train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
            Fraction of timesteps to train on.
        per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
            Whether to track statistics for each prompt separately.
        per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
            Number of reward values to store in the buffer for each prompt.
        per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
            Minimum number of reward values to store in the buffer.
        async_reward_computation (`bool`, *optional*, defaults to `False`):
            Whether to compute rewards asynchronously.
        max_workers (`int`, *optional*, defaults to `2`):
            Maximum number of workers to use for async reward computation.
        negative_prompts (`str`, *optional*, defaults to `""`):
            Comma-separated list of prompts to use as negative examples.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model checkpoint to the Hub.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to `DDPOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance; not forwarded
            to `DDPOConfig`.

    """
    # Unsloth-only field: kept on the config object, never passed to the parent constructor.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    # Unsloth-only field: kept on the config object, never passed to the parent constructor.
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        exp_name = 'main',
        run_name = '',
        seed = 3407,
        log_with = None,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        sample_batch_size = 1,
        sample_num_batches_per_epoch = 2,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        train_num_inner_epochs = 1,
        train_cfg = True,
        train_adv_clip_max = 5.0,
        train_clip_range = 0.0001,
        train_timestep_fraction = 1.0,
        per_prompt_stat_tracking = False,
        per_prompt_stat_tracking_buffer_size = 16,
        per_prompt_stat_tracking_min_count = 16,
        async_reward_computation = False,
        max_workers = 2,
        negative_prompts = '',
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):

        # Every upstream option is forwarded unchanged; anything not named above travels in **kwargs.
        super().__init__(
            exp_name = exp_name,
            run_name = run_name,
            seed = seed,
            log_with = log_with,
            tracker_project_name = tracker_project_name,
            logdir = logdir,
            num_epochs = num_epochs,
            save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision,
            allow_tf32 = allow_tf32,
            resume_from = resume_from,
            sample_num_steps = sample_num_steps,
            sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            sample_batch_size = sample_batch_size,
            sample_num_batches_per_epoch = sample_num_batches_per_epoch,
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1,
            train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            train_num_inner_epochs = train_num_inner_epochs,
            train_cfg = train_cfg,
            train_adv_clip_max = train_adv_clip_max,
            train_clip_range = train_clip_range,
            train_timestep_fraction = train_timestep_fraction,
            per_prompt_stat_tracking = per_prompt_stat_tracking,
            per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
            per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
            async_reward_computation = async_reward_computation,
            max_workers = max_workers,
            negative_prompts = negative_prompts,
            push_to_hub = push_to_hub,**kwargs)
        # The two Unsloth additions are set after the parent constructor has run.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
    pass
|
| 233 |
+
|
| 234 |
+
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
|
| 235 |
+
""""""
|
| 236 |
+
|
| 237 |
+
_tag_names = ["trl", "ddpo"]
|
| 238 |
+
|
| 239 |
+
def __init__(
    self,
    config: DDPOConfig,
    reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
    prompt_function: Callable[[], tuple[str, Any]],
    sd_pipeline: DDPOStableDiffusionPipeline,
    image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
):
    """
    Set up the accelerator, pipeline, optimizer and optional resume state for DDPO training.

    Args:
        config: trainer configuration; stored as `self.config`.
        reward_function: stored as `self.reward_fn`; `compute_rewards` calls it with
            `(images, prompts, prompt_metadata)` and unpacks a `(reward, reward_metadata)` pair.
        prompt_function: stored as `self.prompt_fn`.
        sd_pipeline: diffusion pipeline; this constructor uses its `vae`, `text_encoder`, `unet`,
            `tokenizer`, `autocast`, `get_trainable_layers()` and `set_progress_bar_config()`.
        image_samples_hook: optional callback stored as `self.image_samples_callback`; a warning is
            emitted when it is omitted.

    Raises:
        ValueError: if `config.resume_from` names a directory without any `checkpoint_*` entry, or if
            `self._config_check()` reports a problem.
    """
    if image_samples_hook is None:
        warn("No image_samples_hook provided; no images will be logged")

    self.prompt_fn = prompt_function
    self.reward_fn = reward_function
    self.config = config
    self.image_samples_callback = image_samples_hook

    accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

    if self.config.resume_from:
        self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
        if "checkpoint_" not in os.path.basename(self.config.resume_from):
            # get the most recent checkpoint in this directory
            checkpoints = list(
                filter(
                    lambda x: "checkpoint_" in x,
                    os.listdir(self.config.resume_from),
                )
            )
            if len(checkpoints) == 0:
                raise ValueError(f"No checkpoints found in {self.config.resume_from}")
            # Checkpoint directories are ranked by the integer after the last underscore.
            checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
            self.config.resume_from = os.path.join(
                self.config.resume_from,
                f"checkpoint_{checkpoint_numbers[-1]}",
            )

            # NOTE(review): the scraped source lost its indentation; this line is placed inside the
            # branch that defines `checkpoint_numbers`, since anywhere else it would raise NameError.
            accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

    # number of timesteps within each trajectory to train on
    self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)

    self.accelerator = Accelerator(
        log_with=self.config.log_with,
        mixed_precision=self.config.mixed_precision,
        project_config=accelerator_project_config,
        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
        # the total number of optimizer steps to accumulate across.
        gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
        **self.config.accelerator_kwargs,
    )

    # _config_check is defined outside this view; it returns an (ok, message) pair.
    is_okay, message = self._config_check()
    if not is_okay:
        raise ValueError(message)

    is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

    if self.accelerator.is_main_process:
        # With tensorboard the config dict is passed flat; otherwise it is nested under
        # the "ddpo_trainer_config" key.
        self.accelerator.init_trackers(
            self.config.tracker_project_name,
            config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
            init_kwargs=self.config.tracker_kwargs,
        )

    logger.info(f"\n{config}")

    set_seed(self.config.seed, device_specific=True)

    self.sd_pipeline = sd_pipeline

    self.sd_pipeline.set_progress_bar_config(
        position=1,
        disable=not self.accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    if self.accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif self.accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16
    else:
        inference_dtype = torch.float32

    self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
    self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
    self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

    trainable_layers = self.sd_pipeline.get_trainable_layers()

    # _save_model_hook / _load_model_hook are defined outside this view.
    self.accelerator.register_save_state_pre_hook(self._save_model_hook)
    self.accelerator.register_load_state_pre_hook(self._load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if self.config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # get_trainable_layers() may return either a module (use its parameters) or a plain list.
    self.optimizer = self._setup_optimizer(
        trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
    )

    # Encode the negative prompt once; element [0] of the text encoder output is kept.
    self.neg_prompt_embed = self.sd_pipeline.text_encoder(
        self.sd_pipeline.tokenizer(
            [""] if self.config.negative_prompts is None else self.config.negative_prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)
    )[0]

    if config.per_prompt_stat_tracking:
        self.stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking_buffer_size,
            config.per_prompt_stat_tracking_min_count,
        )

    # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
    # more memory
    self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

    if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
        unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
        # With LoRA only the parameters that require gradients are kept as the trainable set.
        self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
    else:
        self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

    if self.config.async_reward_computation:
        # The executor exists only in async mode; compute_rewards(is_async=True) relies on it.
        self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        self.accelerator.load_state(config.resume_from)
        # Training continues at the epoch after the one encoded in the checkpoint name.
        self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        self.first_epoch = 0
|
| 380 |
+
|
| 381 |
+
def compute_rewards(self, prompt_image_pairs, is_async=False):
    """
    Score every generated batch with the user-supplied reward function.

    Args:
        prompt_image_pairs: iterable of `(images, prompts, prompt_metadata)` triples; each triple is
            passed positionally to `self.reward_fn`, which must return `(reward, reward_metadata)`.
        is_async (bool): when True, the reward function calls are dispatched through `self.executor`
            (created in `__init__` only when `config.async_reward_computation` is set).

    Returns:
        A `zip` iterator yielding two tuples: the rewards as tensors on `self.accelerator.device`,
        and the matching reward metadata, in input order.
    """
    device = self.accelerator.device

    def _resolve(value):
        # `Executor.map` yields the reward function's return values, not futures, so calling
        # `.result()` unconditionally fails for plain tensors/floats. Future-like values (anything
        # with a callable `result`) are still resolved, which keeps reward functions that return
        # futures working.
        result = getattr(value, "result", None)
        return result() if callable(result) else value

    if not is_async:
        rewards = []
        for images, prompts, prompt_metadata in prompt_image_pairs:
            reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
            rewards.append((torch.as_tensor(reward, device=device), reward_metadata))
    else:
        outputs = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
        rewards = [
            (torch.as_tensor(_resolve(reward), device=device), _resolve(reward_metadata))
            for reward, reward_metadata in outputs
        ]

    return zip(*rewards)
|
| 400 |
+
|
| 401 |
+
def step(self, epoch: int, global_step: int):
    """
    Perform a single step of training.

    Args:
        epoch (int): The current epoch.
        global_step (int): The current global step.

    Side Effects:
        - Model weights are updated
        - Logs the statistics to the accelerator trackers.
        - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
        - Saves the accelerator state on the main process when `epoch` is a non-zero multiple of `config.save_freq`.

    Returns:
        global_step (int): The updated global step.

    Raises:
        ValueError: if `self.accelerator.sync_gradients` is false once an inner epoch has finished.

    """
    samples, prompt_image_data = self._generate_samples(
        iterations=self.config.sample_num_batches_per_epoch,
        batch_size=self.config.sample_batch_size,
    )

    # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
    samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
    rewards, rewards_metadata = self.compute_rewards(
        prompt_image_data, is_async=self.config.async_reward_computation
    )

    # Attach each batch's reward and reward metadata to its entry before the logging callback sees it.
    for i, image_data in enumerate(prompt_image_data):
        image_data.extend([rewards[i], rewards_metadata[i]])

    if self.image_samples_callback is not None:
        self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

    rewards = torch.cat(rewards)
    # From here on `rewards` is a numpy array gathered across processes.
    rewards = self.accelerator.gather(rewards).cpu().numpy()

    self.accelerator.log(
        {
            "reward": rewards,
            "epoch": epoch,
            "reward_mean": rewards.mean(),
            "reward_std": rewards.std(),
        },
        step=global_step,
    )

    if self.config.per_prompt_stat_tracking:
        # gather the prompts across processes
        prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
        prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
        advantages = self.stat_tracker.update(prompts, rewards)
    else:
        # Standardise the rewards; the 1e-8 keeps the division finite when the std is zero.
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # ungather advantages; keep the entries corresponding to the samples on this process
    samples["advantages"] = (
        torch.as_tensor(advantages)
        .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
        .to(self.accelerator.device)
    )

    del samples["prompt_ids"]

    total_batch_size, num_timesteps = samples["timesteps"].shape

    for inner_epoch in range(self.config.train_num_inner_epochs):
        # shuffle samples along batch dimension
        perm = torch.randperm(total_batch_size, device=self.accelerator.device)
        samples = {k: v[perm] for k, v in samples.items()}

        # shuffle along time dimension independently for each sample
        # `perms` has one row per sample, each row an independent permutation of range(num_timesteps).
        perms = torch.stack(
            [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
        )

        # Advanced indexing with a column of row indices and `perms` selects
        # samples[key][i, perms[i, j]], i.e. it reorders the time axis of row i by its own permutation.
        for key in ["timesteps", "latents", "next_latents", "log_probs"]:
            samples[key] = samples[key][
                torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
                perms,
            ]

        original_keys = samples.keys()
        original_values = samples.values()
        # rebatch them as user defined train_batch_size is different from sample_batch_size
        reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]

        # Transpose the list of original values
        transposed_values = zip(*reshaped_values)
        # Create new dictionaries for each row of transposed values
        samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]

        self.sd_pipeline.unet.train()
        global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
        # ensure optimization step at the end of the inner epoch
        if not self.accelerator.sync_gradients:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )

    if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
        self.accelerator.save_state()

    return global_step
|
| 506 |
+
|
| 507 |
+
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
    """
    Compute the clipped policy loss and its diagnostics for one denoising step of a batch.

    Args:
        latents (torch.Tensor):
            Latents before the step, shape: [batch_size, num_channels_latents, height, width]
        timesteps (torch.Tensor):
            Timestep associated with each latent, shape: [batch_size]
        next_latents (torch.Tensor):
            Latents after the step, shape: [batch_size, num_channels_latents, height, width]
        log_probs (torch.Tensor):
            Log probabilities recorded for the latents at sampling time, shape: [batch_size]
        advantages (torch.Tensor):
            Advantages of the latents, shape: [batch_size]
        embeds (torch.Tensor):
            Prompt embeddings, shape: [2*batch_size or batch_size, ...]. When `config.train_cfg` is
            enabled, the negative-prompt embeddings are expected to be concatenated in front of the
            prompt embeddings, because the first half of the UNet output is treated as unconditional.

    Returns:
        loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
        `loss` is whatever `self.loss` returns; `approx_kl` and `clipfrac` are means over the batch.
    """
    settings = self.config
    pipeline = self.sd_pipeline

    with self.autocast():
        if not settings.train_cfg:
            noise_pred = pipeline.unet(latents, timesteps, embeds).sample
        else:
            # Single forward pass over a doubled batch: the first half of the output is the
            # unconditional prediction, the second half the text-conditioned one.
            stacked_pred = pipeline.unet(
                torch.cat([latents, latents]),
                torch.cat([timesteps, timesteps]),
                embeds,
            ).sample
            uncond_pred, text_pred = stacked_pred.chunk(2)
            noise_pred = uncond_pred + settings.sample_guidance_scale * (text_pred - uncond_pred)

        # Log probability of `next_latents` given `latents` under the current model.
        step_result = pipeline.scheduler_step(
            noise_pred,
            timesteps,
            latents,
            eta=settings.sample_eta,
            prev_sample=next_latents,
        )
        current_log_prob = step_result.log_probs

    # Bound the advantages symmetrically before they enter the surrogate objective.
    adv_bound = settings.train_adv_clip_max
    bounded_advantages = advantages.clamp(-adv_bound, adv_bound)

    log_ratio = current_log_prob - log_probs
    ratio = log_ratio.exp()

    policy_loss = self.loss(bounded_advantages, settings.train_clip_range, ratio)

    # Diagnostics: quadratic KL estimate and the fraction of ratios outside the clip range.
    approx_kl = 0.5 * (log_ratio ** 2).mean()
    clipfrac = ((ratio - 1.0).abs() > settings.train_clip_range).float().mean()

    return policy_loss, approx_kl, clipfrac
|
| 574 |
+
|
| 575 |
+
def loss(
    self,
    advantages: torch.Tensor,
    clip_range: float,
    ratio: torch.Tensor,
):
    """
    Clipped surrogate policy loss.

    Args:
        advantages (torch.Tensor): Advantage of each sample.
        clip_range (float): Half-width of the interval around 1.0 that `ratio` is clamped to.
        ratio (torch.Tensor): Probability ratio between the current and the sampling policy.

    Returns:
        torch.Tensor: Mean over all elements of the element-wise maximum of the unclipped
        and clipped negative surrogate terms.
    """
    lower_bound = 1.0 - clip_range
    upper_bound = 1.0 + clip_range

    surrogate = -advantages * ratio
    surrogate_clipped = -advantages * ratio.clamp(lower_bound, upper_bound)

    # Taking the maximum of the two negated terms keeps the more pessimistic objective.
    return torch.maximum(surrogate, surrogate_clipped).mean()
|
| 588 |
+
|
| 589 |
+
def _setup_optimizer(self, trainable_layers_parameters):
    """
    Build the AdamW optimizer for the trainable parameters.

    Uses `bitsandbytes.optim.AdamW8bit` when `config.train_use_8bit_adam` is truthy (the
    import happens lazily, only in that case) and `torch.optim.AdamW` otherwise.

    Args:
        trainable_layers_parameters: Parameters handed to the optimizer constructor.

    Returns:
        The optimizer instance, configured from the `train_learning_rate` and
        `train_adam_*` fields of `self.config`.
    """
    settings = self.config

    if not settings.train_use_8bit_adam:
        optimizer_cls = torch.optim.AdamW
    else:
        import bitsandbytes

        optimizer_cls = bitsandbytes.optim.AdamW8bit

    hyperparameters = {
        "lr": settings.train_learning_rate,
        "betas": (settings.train_adam_beta1, settings.train_adam_beta2),
        "weight_decay": settings.train_adam_weight_decay,
        "eps": settings.train_adam_epsilon,
    }
    return optimizer_cls(trainable_layers_parameters, **hyperparameters)
|
| 604 |
+
|
| 605 |
+
def _save_model_hook(self, models, weights, output_dir):
    """
    Save a checkpoint through the pipeline's own `save_checkpoint`.

    Args:
        models: Forwarded unchanged to `sd_pipeline.save_checkpoint`.
        weights: Forwarded to `sd_pipeline.save_checkpoint`; its last entry is removed afterwards.
        output_dir: Destination forwarded to `sd_pipeline.save_checkpoint`.
    """
    self.sd_pipeline.save_checkpoint(models, weights, output_dir)
    # Removing the last entry ensures that accelerate doesn't try to handle saving of the model.
    weights.pop()
|
| 608 |
+
|
| 609 |
+
def _load_model_hook(self, models, input_dir):
    """
    Load a checkpoint through the pipeline's own `load_checkpoint`.

    Args:
        models: Forwarded to `sd_pipeline.load_checkpoint`; its last entry is removed afterwards.
        input_dir: Source location forwarded to `sd_pipeline.load_checkpoint`.
    """
    self.sd_pipeline.load_checkpoint(models, input_dir)
    # Removing the last entry ensures that accelerate doesn't try to handle loading of the model.
    models.pop()
|
| 612 |
+
|
| 613 |
+
def _generate_samples(self, iterations, batch_size):
    """
    Run the diffusion pipeline to collect training samples.

    Args:
        iterations (int): Number of sampling rounds to run
        batch_size (int): Number of prompts drawn from `self.prompt_fn` in each round

    Returns:
        samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
        Each `samples` entry holds the keys "prompt_ids", "prompt_embeds", "timesteps", "latents",
        "next_latents", "log_probs" and "negative_prompt_embeds"; each `prompt_image_pairs` entry
        is `[images, prompts, prompt_metadata]` for the same round.
    """
    pipeline = self.sd_pipeline
    collected_samples = []
    collected_pairs = []

    # Sampling happens with the UNet in eval mode.
    pipeline.unet.eval()

    negative_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

    for _ in range(iterations):
        # `prompt_fn` yields (prompt, metadata) pairs; split them into two parallel tuples.
        drawn = [self.prompt_fn() for _ in range(batch_size)]
        prompts, prompt_metadata = zip(*drawn)

        tokenized = pipeline.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        )
        prompt_ids = tokenized.input_ids.to(self.accelerator.device)
        prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

        with self.autocast():
            sd_output = pipeline(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                output_type="pt",
            )
            images = sd_output.images
            per_step_latents = sd_output.latents
            per_step_log_probs = sd_output.log_probs

        trajectory = torch.stack(per_step_latents, dim=1)  # (batch_size, num_steps + 1, ...)
        log_probs = torch.stack(per_step_log_probs, dim=1)  # (batch_size, num_steps, 1)
        timesteps = pipeline.scheduler.timesteps.repeat(batch_size, 1)  # (batch_size, num_steps)

        round_sample = {
            "prompt_ids": prompt_ids,
            "prompt_embeds": prompt_embeds,
            "timesteps": timesteps,
            "latents": trajectory[:, :-1],  # each entry is the latent before timestep t
            "next_latents": trajectory[:, 1:],  # each entry is the latent after timestep t
            "log_probs": log_probs,
            "negative_prompt_embeds": negative_embeds,
        }
        collected_samples.append(round_sample)
        collected_pairs.append([images, prompts, prompt_metadata])

    return collected_samples, collected_pairs
|
| 674 |
+
|
| 675 |
+
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
    """
    Main training segment: optimise the UNet on already-batched samples.

    Args:
        inner_epoch (int): The current inner epoch
        epoch (int): The current epoch
        global_step (int): The current global step
        batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on

    Side Effects:
        - Model weights are updated
        - Statistics are logged to the accelerator trackers.

    Returns:
        global_step (int): The updated global step
    """
    stats = defaultdict(list)

    for sample in batched_samples:
        if self.config.train_cfg:
            # concat negative prompts to sample prompts to avoid two forward passes
            embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
        else:
            embeds = sample["prompt_embeds"]

        for step_idx in range(self.num_train_timesteps):
            with self.accelerator.accumulate(self.sd_pipeline.unet):
                loss, approx_kl, clipfrac = self.calculate_loss(
                    sample["latents"][:, step_idx],
                    sample["timesteps"][:, step_idx],
                    sample["next_latents"][:, step_idx],
                    sample["log_probs"][:, step_idx],
                    sample["advantages"],
                    embeds,
                )
                stats["approx_kl"].append(approx_kl)
                stats["clipfrac"].append(clipfrac)
                stats["loss"].append(loss)

                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients:
                    # `trainable_layers` is either a module-like object exposing
                    # `.parameters()` or already a plain list of parameters.
                    if isinstance(self.trainable_layers, list):
                        clip_targets = self.trainable_layers
                    else:
                        clip_targets = self.trainable_layers.parameters()
                    self.accelerator.clip_grad_norm_(clip_targets, self.config.train_max_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if self.accelerator.sync_gradients:
                # log training-related stuff, then start a fresh accumulation window
                averaged = {key: torch.mean(torch.stack(values)) for key, values in stats.items()}
                averaged = self.accelerator.reduce(averaged, reduction="mean")
                averaged.update({"epoch": epoch, "inner_epoch": inner_epoch})
                self.accelerator.log(averaged, step=global_step)
                global_step += 1
                stats = defaultdict(list)

    return global_step
|
| 735 |
+
|
| 736 |
+
def _config_check(self) -> tuple[bool, str]:
    """
    Validate that the sampling and training batch settings are mutually consistent.

    Returns:
        `(True, "")` when every check passes, otherwise `(False, message)` where `message`
        describes the first check that failed.
    """
    settings = self.config
    world_size = self.accelerator.num_processes

    samples_per_epoch = settings.sample_batch_size * world_size * settings.sample_num_batches_per_epoch
    total_train_batch_size = (
        settings.train_batch_size * world_size * settings.train_gradient_accumulation_steps
    )

    sample_bs = settings.sample_batch_size
    train_bs = settings.train_batch_size

    if not (sample_bs >= train_bs):
        message = f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})"
        return False, message

    if sample_bs % train_bs != 0:
        message = f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})"
        return False, message

    if samples_per_epoch % total_train_batch_size != 0:
        message = f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})"
        return False, message

    return True, ""
|
| 762 |
+
|
| 763 |
+
def train(self, epochs: Optional[int] = None):
    """
    Run `self.step` once per epoch, from `self.first_epoch` up to (but excluding) `epochs`.

    Args:
        epochs (`int` or `None`, *optional*, defaults to `None`):
            Exclusive upper bound of the epoch range. Falls back to `self.config.num_epochs`
            when `None`.
    """
    last_epoch = self.config.num_epochs if epochs is None else epochs

    # The global step counter is threaded through successive `step` calls.
    step_counter = 0
    for epoch_index in range(self.first_epoch, last_epoch):
        step_counter = self.step(epoch_index, step_counter)
|
| 772 |
+
|
| 773 |
+
def _save_pretrained(self, save_directory):
    """
    Save the pipeline to `save_directory`, then generate the model card.

    Args:
        save_directory: Target location forwarded to `sd_pipeline.save_pretrained`.
    """
    pipeline = self.sd_pipeline
    pipeline.save_pretrained(save_directory)
    # The model card is written with default arguments (no model name, dataset name or tags).
    self.create_model_card()
|
| 776 |
+
|
| 777 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is written when
    `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The argument itself is never modified.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    model_config = self.model.config
    if hasattr(model_config, "_name_or_path") and not os.path.isdir(model_config._name_or_path):
        base_model = model_config._name_or_path
    else:
        base_model = None

    # Normalise `tags` into a fresh list so that appending below never mutates the caller's list.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(model_config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @inproceedings{black2024training,
        title = {{Training Diffusion Models with Reinforcement Learning}},
        author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
        year = 2024,
        booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
        publisher = {OpenReview.net},
        url = {https://openreview.net/forum?id=YCWjhGrJFD},
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="DDPO",
        trainer_citation=citation,
        paper_title="Training Diffusion Models with Reinforcement Learning",
        paper_id="2305.13301",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 834 |
+
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
    """

    The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models.
    Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
    As of now only Stable Diffusion based pipelines are supported

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
         details. When `None` is passed, a default `UnslothDDPOConfig()` is created.
        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images

    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        # The configuration parameter of this trainer is `config`. The previous guard read an
        # undefined local `args`, which raised UnboundLocalError on every construction.
        if config is None: config = UnslothDDPOConfig()
        other_metrics = []

        # Imported lazily, right before it is needed to patch the RL statistics logging.
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ddpo_trainer', other_metrics)

        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,**kwargs)

pass
|
unsloth_compiled_cache/UnslothDPOTrainer.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unsloth_compiled_cache/UnslothGKDTrainer.py
ADDED
|
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, deepspeed, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options dictionary handed to `torch.compile(..., options=...)` further down in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` evaluated only at the positions given by `index`.

    Args:
        logits: Tensor of scores; the softmax is taken over its last dimension.
        index: Integer tensor of indices into the last dimension of `logits`, with one
            fewer dimension than `logits`.

    Returns:
        float32 tensor shaped like `index` holding the selected log-probabilities.
    """
    # Work in float32 regardless of the input dtype.
    logits_fp32 = logits.to(torch.float32)
    picked_logits = logits_fp32.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    log_normalizer = logits_fp32.logsumexp(dim = -1)
    # log_softmax(x_i) = x_i - logsumexp(x)
    return picked_logits - log_normalizer
|
| 42 |
+
@dataclass
|
| 43 |
+
class UnslothGKDConfig(GKDConfig):
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
Configuration class for [`GKDTrainer`].
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
| 50 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 51 |
+
lmbda (`float`, *optional*, defaults to `0.5`):
|
| 52 |
+
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
| 53 |
+
student-generated outputs).
|
| 54 |
+
beta (`float`, *optional*, defaults to `0.5`):
|
| 55 |
+
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
| 56 |
+
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
| 57 |
+
max_new_tokens (`int`, *optional*, defaults to `128`):
|
| 58 |
+
Maximum number of tokens to generate per completion.
|
| 59 |
+
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
| 60 |
+
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
|
| 61 |
+
being trained.
|
| 62 |
+
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
|
| 63 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
| 64 |
+
from a string.
|
| 65 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 66 |
+
Whether to disable dropout in the model.
|
| 67 |
+
seq_kd (`bool`, *optional*, defaults to `False`):
|
| 68 |
+
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
|
| 69 |
+
on teacher-generated output).
|
| 70 |
+
|
| 71 |
+
"""
|
| 72 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 73 |
+
default = None,
|
| 74 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 75 |
+
)
|
| 76 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 77 |
+
default = -1,
|
| 78 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 79 |
+
)
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
output_dir = None,
|
| 83 |
+
overwrite_output_dir = None,
|
| 84 |
+
do_train = False,
|
| 85 |
+
do_eval = False,
|
| 86 |
+
do_predict = False,
|
| 87 |
+
eval_strategy = 'no',
|
| 88 |
+
prediction_loss_only = False,
|
| 89 |
+
per_device_train_batch_size = 4,
|
| 90 |
+
per_device_eval_batch_size = 4,
|
| 91 |
+
per_gpu_train_batch_size = None,
|
| 92 |
+
per_gpu_eval_batch_size = None,
|
| 93 |
+
gradient_accumulation_steps = 2,
|
| 94 |
+
eval_accumulation_steps = 2,
|
| 95 |
+
eval_delay = 0,
|
| 96 |
+
torch_empty_cache_steps = 250,
|
| 97 |
+
learning_rate = 5e-05,
|
| 98 |
+
weight_decay = 0.01,
|
| 99 |
+
adam_beta1 = 0.9,
|
| 100 |
+
adam_beta2 = 0.999,
|
| 101 |
+
adam_epsilon = 1e-08,
|
| 102 |
+
max_grad_norm = 1.0,
|
| 103 |
+
num_train_epochs = 3.0,
|
| 104 |
+
max_steps = -1,
|
| 105 |
+
lr_scheduler_type = 'linear',
|
| 106 |
+
warmup_ratio = 0.1,
|
| 107 |
+
warmup_steps = 0,
|
| 108 |
+
log_level = 'passive',
|
| 109 |
+
log_level_replica = 'warning',
|
| 110 |
+
log_on_each_node = True,
|
| 111 |
+
logging_dir = None,
|
| 112 |
+
logging_strategy = 'steps',
|
| 113 |
+
logging_first_step = False,
|
| 114 |
+
logging_steps = 1,
|
| 115 |
+
logging_nan_inf_filter = False,
|
| 116 |
+
save_strategy = 'steps',
|
| 117 |
+
save_steps = 500,
|
| 118 |
+
save_total_limit = None,
|
| 119 |
+
save_safetensors = True,
|
| 120 |
+
save_on_each_node = False,
|
| 121 |
+
save_only_model = False,
|
| 122 |
+
restore_callback_states_from_checkpoint = False,
|
| 123 |
+
no_cuda = False,
|
| 124 |
+
use_cpu = False,
|
| 125 |
+
use_mps_device = False,
|
| 126 |
+
seed = 3407,
|
| 127 |
+
data_seed = 3407,
|
| 128 |
+
jit_mode_eval = False,
|
| 129 |
+
use_ipex = False,
|
| 130 |
+
bf16 = False,
|
| 131 |
+
fp16 = False,
|
| 132 |
+
fp16_opt_level = 'O1',
|
| 133 |
+
half_precision_backend = 'auto',
|
| 134 |
+
bf16_full_eval = False,
|
| 135 |
+
fp16_full_eval = False,
|
| 136 |
+
tf32 = None,
|
| 137 |
+
local_rank = -1,
|
| 138 |
+
ddp_backend = None,
|
| 139 |
+
tpu_num_cores = None,
|
| 140 |
+
tpu_metrics_debug = False,
|
| 141 |
+
debug = '',
|
| 142 |
+
dataloader_drop_last = False,
|
| 143 |
+
eval_steps = None,
|
| 144 |
+
dataloader_num_workers = 0,
|
| 145 |
+
dataloader_prefetch_factor = None,
|
| 146 |
+
past_index = -1,
|
| 147 |
+
run_name = None,
|
| 148 |
+
disable_tqdm = None,
|
| 149 |
+
remove_unused_columns = True,
|
| 150 |
+
label_names = None,
|
| 151 |
+
load_best_model_at_end = False,
|
| 152 |
+
metric_for_best_model = None,
|
| 153 |
+
greater_is_better = None,
|
| 154 |
+
ignore_data_skip = False,
|
| 155 |
+
fsdp = '',
|
| 156 |
+
fsdp_min_num_params = 0,
|
| 157 |
+
fsdp_config = None,
|
| 158 |
+
tp_size = 0,
|
| 159 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
| 160 |
+
accelerator_config = None,
|
| 161 |
+
deepspeed = None,
|
| 162 |
+
label_smoothing_factor = 0.0,
|
| 163 |
+
optim = 'adamw_8bit',
|
| 164 |
+
optim_args = None,
|
| 165 |
+
adafactor = False,
|
| 166 |
+
group_by_length = False,
|
| 167 |
+
length_column_name = 'length',
|
| 168 |
+
report_to = None,
|
| 169 |
+
ddp_find_unused_parameters = None,
|
| 170 |
+
ddp_bucket_cap_mb = None,
|
| 171 |
+
ddp_broadcast_buffers = None,
|
| 172 |
+
dataloader_pin_memory = True,
|
| 173 |
+
dataloader_persistent_workers = False,
|
| 174 |
+
skip_memory_metrics = True,
|
| 175 |
+
use_legacy_prediction_loop = False,
|
| 176 |
+
push_to_hub = False,
|
| 177 |
+
resume_from_checkpoint = None,
|
| 178 |
+
hub_model_id = None,
|
| 179 |
+
hub_strategy = 'every_save',
|
| 180 |
+
hub_token = None,
|
| 181 |
+
hub_private_repo = None,
|
| 182 |
+
hub_always_push = False,
|
| 183 |
+
gradient_checkpointing = False,
|
| 184 |
+
gradient_checkpointing_kwargs = None,
|
| 185 |
+
include_inputs_for_metrics = False,
|
| 186 |
+
eval_do_concat_batches = True,
|
| 187 |
+
fp16_backend = 'auto',
|
| 188 |
+
evaluation_strategy = None,
|
| 189 |
+
push_to_hub_model_id = None,
|
| 190 |
+
push_to_hub_organization = None,
|
| 191 |
+
push_to_hub_token = None,
|
| 192 |
+
mp_parameters = '',
|
| 193 |
+
auto_find_batch_size = False,
|
| 194 |
+
full_determinism = False,
|
| 195 |
+
torchdynamo = None,
|
| 196 |
+
ray_scope = 'last',
|
| 197 |
+
ddp_timeout = 1800,
|
| 198 |
+
torch_compile = False,
|
| 199 |
+
torch_compile_backend = None,
|
| 200 |
+
torch_compile_mode = None,
|
| 201 |
+
dispatch_batches = None,
|
| 202 |
+
split_batches = None,
|
| 203 |
+
include_tokens_per_second = False,
|
| 204 |
+
include_num_input_tokens_seen = False,
|
| 205 |
+
neftune_noise_alpha = None,
|
| 206 |
+
optim_target_modules = None,
|
| 207 |
+
batch_eval_metrics = False,
|
| 208 |
+
eval_on_start = False,
|
| 209 |
+
use_liger_kernel = False,
|
| 210 |
+
eval_use_gather_object = False,
|
| 211 |
+
average_tokens_across_devices = False,
|
| 212 |
+
model_init_kwargs = None,
|
| 213 |
+
use_liger = False,
|
| 214 |
+
dataset_text_field = 'text',
|
| 215 |
+
dataset_kwargs = None,
|
| 216 |
+
dataset_num_proc = None,
|
| 217 |
+
max_seq_length = None,
|
| 218 |
+
packing = False,
|
| 219 |
+
eval_packing = None,
|
| 220 |
+
dataset_batch_size = None,
|
| 221 |
+
num_of_sequences = None,
|
| 222 |
+
chars_per_token = None,
|
| 223 |
+
temperature = 0.9,
|
| 224 |
+
lmbda = 0.5,
|
| 225 |
+
beta = 0.5,
|
| 226 |
+
max_new_tokens = 128,
|
| 227 |
+
teacher_model_name_or_path = None,
|
| 228 |
+
teacher_model_init_kwargs = None,
|
| 229 |
+
disable_dropout = True,
|
| 230 |
+
seq_kd = False,
|
| 231 |
+
vllm_sampling_params = None,
|
| 232 |
+
unsloth_num_chunks = -1,
|
| 233 |
+
**kwargs,
|
| 234 |
+
):
|
| 235 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 236 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 237 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 238 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 239 |
+
save_strategy = 'no'
|
| 240 |
+
if dataset_num_proc is None:
|
| 241 |
+
from multiprocessing import cpu_count
|
| 242 |
+
dataset_num_proc = cpu_count()
|
| 243 |
+
|
| 244 |
+
super().__init__(
|
| 245 |
+
output_dir = output_dir,
|
| 246 |
+
overwrite_output_dir = overwrite_output_dir,
|
| 247 |
+
do_train = do_train,
|
| 248 |
+
do_eval = do_eval,
|
| 249 |
+
do_predict = do_predict,
|
| 250 |
+
eval_strategy = eval_strategy,
|
| 251 |
+
prediction_loss_only = prediction_loss_only,
|
| 252 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 253 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 254 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 255 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 256 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 257 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 258 |
+
eval_delay = eval_delay,
|
| 259 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 260 |
+
learning_rate = learning_rate,
|
| 261 |
+
weight_decay = weight_decay,
|
| 262 |
+
adam_beta1 = adam_beta1,
|
| 263 |
+
adam_beta2 = adam_beta2,
|
| 264 |
+
adam_epsilon = adam_epsilon,
|
| 265 |
+
max_grad_norm = max_grad_norm,
|
| 266 |
+
num_train_epochs = num_train_epochs,
|
| 267 |
+
max_steps = max_steps,
|
| 268 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 269 |
+
warmup_ratio = warmup_ratio,
|
| 270 |
+
warmup_steps = warmup_steps,
|
| 271 |
+
log_level = log_level,
|
| 272 |
+
log_level_replica = log_level_replica,
|
| 273 |
+
log_on_each_node = log_on_each_node,
|
| 274 |
+
logging_dir = logging_dir,
|
| 275 |
+
logging_strategy = logging_strategy,
|
| 276 |
+
logging_first_step = logging_first_step,
|
| 277 |
+
logging_steps = logging_steps,
|
| 278 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 279 |
+
save_strategy = save_strategy,
|
| 280 |
+
save_steps = save_steps,
|
| 281 |
+
save_total_limit = save_total_limit,
|
| 282 |
+
save_safetensors = save_safetensors,
|
| 283 |
+
save_on_each_node = save_on_each_node,
|
| 284 |
+
save_only_model = save_only_model,
|
| 285 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 286 |
+
no_cuda = no_cuda,
|
| 287 |
+
use_cpu = use_cpu,
|
| 288 |
+
use_mps_device = use_mps_device,
|
| 289 |
+
seed = seed,
|
| 290 |
+
data_seed = data_seed,
|
| 291 |
+
jit_mode_eval = jit_mode_eval,
|
| 292 |
+
use_ipex = use_ipex,
|
| 293 |
+
bf16 = bf16,
|
| 294 |
+
fp16 = fp16,
|
| 295 |
+
fp16_opt_level = fp16_opt_level,
|
| 296 |
+
half_precision_backend = half_precision_backend,
|
| 297 |
+
bf16_full_eval = bf16_full_eval,
|
| 298 |
+
fp16_full_eval = fp16_full_eval,
|
| 299 |
+
tf32 = tf32,
|
| 300 |
+
local_rank = local_rank,
|
| 301 |
+
ddp_backend = ddp_backend,
|
| 302 |
+
tpu_num_cores = tpu_num_cores,
|
| 303 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
| 304 |
+
debug = debug,
|
| 305 |
+
dataloader_drop_last = dataloader_drop_last,
|
| 306 |
+
eval_steps = eval_steps,
|
| 307 |
+
dataloader_num_workers = dataloader_num_workers,
|
| 308 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 309 |
+
past_index = past_index,
|
| 310 |
+
run_name = run_name,
|
| 311 |
+
disable_tqdm = disable_tqdm,
|
| 312 |
+
remove_unused_columns = remove_unused_columns,
|
| 313 |
+
label_names = label_names,
|
| 314 |
+
load_best_model_at_end = load_best_model_at_end,
|
| 315 |
+
metric_for_best_model = metric_for_best_model,
|
| 316 |
+
greater_is_better = greater_is_better,
|
| 317 |
+
ignore_data_skip = ignore_data_skip,
|
| 318 |
+
fsdp = fsdp,
|
| 319 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
| 320 |
+
fsdp_config = fsdp_config,
|
| 321 |
+
tp_size = tp_size,
|
| 322 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 323 |
+
accelerator_config = accelerator_config,
|
| 324 |
+
deepspeed = deepspeed,
|
| 325 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 326 |
+
optim = optim,
|
| 327 |
+
optim_args = optim_args,
|
| 328 |
+
adafactor = adafactor,
|
| 329 |
+
group_by_length = group_by_length,
|
| 330 |
+
length_column_name = length_column_name,
|
| 331 |
+
report_to = report_to,
|
| 332 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 333 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 334 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 335 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
| 336 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 337 |
+
skip_memory_metrics = skip_memory_metrics,
|
| 338 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 339 |
+
push_to_hub = push_to_hub,
|
| 340 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
| 341 |
+
hub_model_id = hub_model_id,
|
| 342 |
+
hub_strategy = hub_strategy,
|
| 343 |
+
hub_token = hub_token,
|
| 344 |
+
hub_private_repo = hub_private_repo,
|
| 345 |
+
hub_always_push = hub_always_push,
|
| 346 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 347 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 348 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 349 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 350 |
+
fp16_backend = fp16_backend,
|
| 351 |
+
evaluation_strategy = evaluation_strategy,
|
| 352 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
| 353 |
+
push_to_hub_organization = push_to_hub_organization,
|
| 354 |
+
push_to_hub_token = push_to_hub_token,
|
| 355 |
+
mp_parameters = mp_parameters,
|
| 356 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 357 |
+
full_determinism = full_determinism,
|
| 358 |
+
torchdynamo = torchdynamo,
|
| 359 |
+
ray_scope = ray_scope,
|
| 360 |
+
ddp_timeout = ddp_timeout,
|
| 361 |
+
torch_compile = torch_compile,
|
| 362 |
+
torch_compile_backend = torch_compile_backend,
|
| 363 |
+
torch_compile_mode = torch_compile_mode,
|
| 364 |
+
dispatch_batches = dispatch_batches,
|
| 365 |
+
split_batches = split_batches,
|
| 366 |
+
include_tokens_per_second = include_tokens_per_second,
|
| 367 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 368 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 369 |
+
optim_target_modules = optim_target_modules,
|
| 370 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 371 |
+
eval_on_start = eval_on_start,
|
| 372 |
+
use_liger_kernel = use_liger_kernel,
|
| 373 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 374 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 375 |
+
model_init_kwargs = model_init_kwargs,
|
| 376 |
+
use_liger = use_liger,
|
| 377 |
+
dataset_text_field = dataset_text_field,
|
| 378 |
+
dataset_kwargs = dataset_kwargs,
|
| 379 |
+
dataset_num_proc = dataset_num_proc,
|
| 380 |
+
max_seq_length = max_seq_length,
|
| 381 |
+
packing = packing,
|
| 382 |
+
eval_packing = eval_packing,
|
| 383 |
+
dataset_batch_size = dataset_batch_size,
|
| 384 |
+
num_of_sequences = num_of_sequences,
|
| 385 |
+
chars_per_token = chars_per_token,
|
| 386 |
+
temperature = temperature,
|
| 387 |
+
lmbda = lmbda,
|
| 388 |
+
beta = beta,
|
| 389 |
+
max_new_tokens = max_new_tokens,
|
| 390 |
+
teacher_model_name_or_path = teacher_model_name_or_path,
|
| 391 |
+
teacher_model_init_kwargs = teacher_model_init_kwargs,
|
| 392 |
+
disable_dropout = disable_dropout,
|
| 393 |
+
seq_kd = seq_kd,**kwargs)
|
| 394 |
+
self.vllm_sampling_params = vllm_sampling_params
|
| 395 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
| 396 |
+
pass
|
| 397 |
+
|
| 398 |
+
class _UnslothGKDTrainer(SFTTrainer):
|
| 399 |
+
_tag_names = ["trl", "gkd"]
|
| 400 |
+
|
| 401 |
+
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
    args: Optional[GKDConfig] = None,
    data_collator: Optional[DataCollator] = None,  # type: ignore
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional["PeftConfig"] = None,
    formatting_func: Optional[Callable] = None,
):
    """
    Set up the student trainer, load/prepare the teacher model and build the generation config.

    Args:
        model: Student model (or its name), forwarded to the parent trainer.
        teacher_model: Teacher model instance, or a string that is loaded with `from_pretrained`.
        args: GKD configuration. It is dereferenced immediately, so it must not be `None` here.
        data_collator: Ignored; it is always replaced by a `DataCollatorForChatML`.
        train_dataset, eval_dataset, processing_class, compute_metrics, callbacks, optimizers,
        preprocess_logits_for_metrics, peft_config, formatting_func: Forwarded unchanged to the parent
            trainer.

    Raises:
        ValueError: If `args.teacher_model_init_kwargs` is set but `teacher_model` is not a string.
    """
    # add remove_unused_columns=False to the dataclass args
    args.remove_unused_columns = False
    # The caller-supplied collator is discarded on purpose: the ChatML collator is always used.
    data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)

    super().__init__(
        model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        peft_config=peft_config,
        formatting_func=formatting_func,
    )

    if args.teacher_model_init_kwargs is None:
        teacher_model_init_kwargs = {}
    elif not isinstance(teacher_model, str):
        raise ValueError(
            "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
        )
    else:
        teacher_model_init_kwargs = args.teacher_model_init_kwargs
        # Only a string name other than "auto" needs resolving to a torch dtype. A missing key,
        # `None`, "auto" or an actual `torch.dtype` is left untouched (the previous unconditional
        # indexing raised KeyError / TypeError for the missing-key and torch.dtype cases).
        torch_dtype = teacher_model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, str) and torch_dtype != "auto":
            teacher_model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)

    if isinstance(teacher_model, str):
        if args.use_liger:
            teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
        else:
            teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(self.model)

    if self.is_deepspeed_enabled:
        self.teacher_model = self._prepare_deepspeed(teacher_model)
    else:
        self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

    self.lmbda = args.lmbda
    self.beta = args.beta
    self.temperature = args.temperature
    self.seq_kd = args.seq_kd

    self.generation_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        do_sample=True,
        top_k=0,
        use_cache=False if args.gradient_checkpointing else True,
        pad_token_id=self.processing_class.pad_token_id,
    )
    # Set custom EOS tokens if they are specified by the model's generation
    # config. This is important for models with the Llama 3 chat template,
    # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
    # turns or messages.
    if (
        hasattr(self.model.generation_config, "eos_token_id")
        and self.model.generation_config.eos_token_id is not None
    ):
        self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
|
| 489 |
+
|
| 490 |
+
def _prepare_dataset(self, dataset, *args):
    """
    Run the parent dataset preparation while keeping the original `messages` column.

    Per the upstream note, `SFTTrainer._prepare_dataset()` applies the chat template and renames the
    `messages` column to `text`; GKD still needs `messages`, so a copy is parked under a temporary
    name and restored afterwards.
    """
    # Park a copy of the conversations where the parent will not touch it.
    with_backup = dataset.add_column("_messages", dataset["messages"])
    prepared = super()._prepare_dataset(with_backup, *args)
    # Put the preserved conversations back under their original name.
    return prepared.rename_column("_messages", "messages")
|
| 497 |
+
|
| 498 |
+
@staticmethod
def generalized_jsd_loss(
    student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
):
    """
    Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
    of https://huggingface.co/papers/2306.13649 for the definition.

    With teacher distribution P and student distribution Q, the mixture is M = beta * P + (1 - beta) * Q and
    the loss is beta * KL(P || M) + (1 - beta) * KL(Q || M). The endpoints are handled explicitly:
    `beta == 0` returns KL(P || Q) and `beta == 1` returns KL(Q || P), which also avoids taking log(0).

    Args:
        student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
        teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
        labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
        beta: Interpolation coefficient between 0 and 1 (default: 0.5)
        temperature: Softmax temperature (default: 1.0)
        reduction: Specifies the reduction to apply to the output (default: 'batchmean')

    Returns:
        loss: Scalar tensor with the generalized JSD loss ('batchmean', 'sum', 'mean'), or the unreduced
        per-element divergence for any other `reduction` value.
    """

    # Apply temperature scaling
    student_logits = student_logits / temperature
    teacher_logits = teacher_logits / temperature

    # Work in log space for both distributions
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    # F.kl_div(input, target, log_target=True) evaluates exp(target) * (target - input),
    # i.e. KL(target || input): the argument order is swapped w.r.t. the mathematical notation.
    if beta == 0:
        # Forward KL: KL(teacher || student)
        jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
    elif beta == 1:
        # Reverse KL: KL(student || teacher)
        jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
    else:
        beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
        # log(beta * P + (1 - beta) * Q) computed stably via logsumexp over the two weighted terms.
        # The teacher carries weight `beta`, matching the weight of its KL term below.
        mixture_log_probs = torch.logsumexp(
            torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
            dim=0,
        )

        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

        # Compute the Generalized Jensen-Shannon Divergence
        jsd = beta * kl_teacher + (1 - beta) * kl_student

    # Masking: keep only positions whose label is not the ignore index
    if labels is not None:
        mask = labels != -100
        jsd = jsd[mask]

    # Apply reduction
    if reduction == "batchmean":
        # Normalise by the number of contributing token positions.
        return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
    elif reduction == "sum":
        return jsd.sum()
    elif reduction == "mean":
        return jsd.mean()
    else:
        return jsd
|
| 556 |
+
|
| 557 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """
    Compute the generalized JSD loss between student and teacher on the completion tokens.

    Both models are run on the same `input_ids` / `attention_mask`; the teacher runs in eval mode
    without gradients. Logits and labels are cut at the prompt width (`inputs["prompts"].shape[1]`)
    so that only completion positions contribute.

    Returns:
        The loss, or `(loss, student_outputs)` when `return_outputs` is true.
    """
    # Student forward pass (gradients flow through this one).
    outputs_student = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )

    # Teacher forward pass: eval mode, no gradient tracking.
    self.teacher_model.eval()
    with torch.no_grad():
        outputs_teacher = self.teacher_model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    # Logits at position i predict token i + 1, hence the window starts one step before the
    # first completion token and stops before the last position.
    prompt_width = inputs["prompts"].shape[1]
    completion_window = slice(prompt_width - 1, -1)

    loss = self.generalized_jsd_loss(
        student_logits=outputs_student.logits[:, completion_window, :],
        teacher_logits=outputs_teacher.logits[:, completion_window, :],
        labels=inputs["labels"][:, prompt_width:],
        beta=self.beta,
    )

    # empty cache
    empty_cache()

    if return_outputs:
        return (loss, outputs_student)
    return loss
|
| 591 |
+
|
| 592 |
+
@staticmethod
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
    """
    Generate sequences from the prompts and derive a matching attention mask and labels.

    Args:
        model: Object exposing `generate(...)`.
        inputs: Mapping with `"prompts"` and, optionally, `"prompt_attention_mask"`.
        generation_config: Passed through to `model.generate`.
        pad_token_id: When given, positions equal to it get attention 0 and label -100.

    Returns:
        Tuple `(generated_tokens, attention_mask, labels)`.
    """
    # Only the prompt is fed to the model; the output holds the full returned sequences.
    generation = model.generate(
        input_ids=inputs["prompts"],
        attention_mask=inputs.get("prompt_attention_mask", None),
        generation_config=generation_config,
        return_dict_in_generate=True,
    )
    sequences = generation.sequences

    attention = torch.ones_like(sequences)
    labels = sequences.clone()

    if pad_token_id is not None:
        # Padding is neither attended to nor scored.
        is_padding = sequences == pad_token_id
        labels[is_padding] = -100
        attention[is_padding] = 0

    return sequences, attention, labels
|
| 614 |
+
|
| 615 |
+
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """
    Perform a training step for the Generalized Knowledge Distillation (GKD) model.

    When `self.seq_kd` is set, the batch is first replaced by sequences generated by the teacher.
    Then, with probability `self.lmbda`, it is replaced by sequences generated by the student
    (the on-policy branch described in the GKD paper). The resulting batch is handed to the
    parent `training_step`.
    """

    def _swap_in_generations(generator):
        # Generate from the prompts with `generator` and overwrite the batch tensors in place.
        with unwrap_model_for_generation(generator, self.accelerator) as unwrapped_model:
            token_ids, attention, labels = self.generate_on_policy_outputs(
                unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
            )
        inputs["input_ids"] = token_ids
        inputs["attention_mask"] = attention
        inputs["labels"] = labels

    if self.seq_kd:
        _swap_in_generations(self.teacher_model)

    # The random draw happens on every step, after any teacher generation.
    if random.random() <= self.lmbda:
        _swap_in_generations(model)

    return super().training_step(model, inputs, num_items_in_batch)
|
| 644 |
+
|
| 645 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
    """
    Wrap `model` with DeepSpeed using a copy of the accelerator's DeepSpeed config and put it in eval mode.

    Adapted from accelerate:
    https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
    """
    # Work on a deep copy so the plugin's own config is never modified.
    ds_config = deepcopy(self.accelerator.state.deepspeed_plugin.deepspeed_config)

    if model is not None and hasattr(model, "config"):
        if getattr(model.config, "hidden_sizes", None):
            hidden_size = max(model.config.hidden_sizes)
        else:
            hidden_size = getattr(model.config, "hidden_size", None)

        if hidden_size is not None and ds_config["zero_optimization"]["stage"] == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            ds_config.update(
                {
                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                }
            )

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
    if ds_config["zero_optimization"]["stage"] != 3:
        ds_config["zero_optimization"]["stage"] = 0

    # Only the engine (first element of the returned tuple) is kept.
    model, *_ = deepspeed.initialize(model=model, config=ds_config)
    model.eval()
    return model
|
| 675 |
+
|
| 676 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on processes
    for which `is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The argument itself is never modified.
    """
    if not self.is_world_process_zero():
        return

    # A local directory path is not a meaningful base-model identifier, so it is dropped.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalise to a fresh list so that appending below cannot mutate the caller's list.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @inproceedings{agarwal2024on-policy,
        title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
        author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
        year = 2024,
        booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
        publisher = {OpenReview.net},
        url = {https://openreview.net/forum?id=3zKtaqxLhW},
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="GKD",
        trainer_citation=citation,
        paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
        paper_id="2306.13649",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 733 |
+
class UnslothGKDTrainer(_UnslothGKDTrainer):
    """
    Wrapper around `_UnslothGKDTrainer` that normalises the training arguments before delegating.

    Before calling the parent constructor it reconciles fp16/bf16 settings with the model dtype,
    adjusts evaluation settings, propagates the model's `max_seq_length`, forces right padding on
    the processing class and adapts the data collator. Afterwards it swaps any NEFTune hook for an
    attribute on the input embeddings.
    """

    def __init__(
        self,
        model = None,
        teacher_model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        if args is None:
            args = UnslothGKDConfig()

        # ---- Mixed-precision reconciliation ---------------------------------------------------
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
        if force_float32:
            print('Unsloth: Switching to float32 training since model cannot work with float16')
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        # Model dtype: taken from the config, falling back to the input embedding dtype.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None:
            dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if not force_float32:
            if float16 and use_bf16:
                raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
            if not float16 and use_fp16:
                raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif not (use_bf16 or use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: follow the model's own dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults --------------------------------------------------------------
        # NOTE(review): this reads `eval_dataset` off `args`, not the `eval_dataset` parameter of
        # this constructor — confirm which one was intended. Behaviour is kept as-is.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None:
                args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of exactly 8 is lowered to the train batch size when that is smaller.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # ---- Full-eval precision flags --------------------------------------------------------
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Logits are requested via the environment whenever a metrics hook is supplied.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # ---- Sequence length ------------------------------------------------------------------
        # Only relevant when the config exposes `max_seq_length`; an unset value is filled from the model.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model.max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Padding side ---------------------------------------------------------------------
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'
        tokenizer_like = processing_class

        # ---- Data collator adaptation ---------------------------------------------------------
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Pick the collator flavour that matches whether the dataset carries a `labels` column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(tokenizer_like, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(tokenizer_like)
        else:
            if hasattr(args, 'remove_unused_columns'):
                args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'):
                args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'):
                args.dataset_kwargs = {'skip_prepare_dataset': True}

        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Objects without `pad` that wrap a `.tokenizer` get their collator rebuilt on the inner tokenizer.
            if not hasattr(tokenizer_like, 'pad') and hasattr(tokenizer_like, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(tokenizer_like.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(tokenizer_like.tokenizer, mlm = False)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('gkd_trainer', other_metrics)

        super().__init__(
            model = model,
            teacher_model = teacher_model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,
            **kwargs
        )

        # ---- NEFTune --------------------------------------------------------------------------
        # Drop any hook the parent registered and expose the noise alpha on the embeddings instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
unsloth_compiled_cache/UnslothGRPOTrainer.py
ADDED
|
@@ -0,0 +1,1438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, wandb, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, wandb, GRPOTrainer, Trainer, gather, os, torch)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Option dictionary handed to every `torch.compile(..., options = ...)` call
# in this module, so all compiled functions share the same settings.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Return the log-softmax value of `logits` at the positions given by `index`.

    Uses the identity ``log_softmax(x)_i = x_i - logsumexp(x)`` over the last
    dimension instead of calling ``log_softmax`` on the whole tensor.  The
    arithmetic is carried out in float32 whatever the incoming dtype is.

    Args:
        logits: tensor whose last dimension is reduced by ``logsumexp``.
        index: integer tensor of positions along that last dimension; it is
            unsqueezed on the last axis before gathering, so it is assumed
            to have shape ``logits.shape[:-1]``.

    Returns:
        The selected log-probabilities, one per entry of `index`.
    """
    logits_fp32 = logits.to(torch.float32)
    # x_i : the raw logit at each requested position.
    picked = logits_fp32.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    # logsumexp(x) : the log of the softmax normaliser.
    log_normaliser = logits_fp32.logsumexp(dim = -1)
    return picked - log_normaliser
|
| 42 |
+
|
| 43 |
+
def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
    """Compute the masked GRPO loss plus two logging metrics.

    Args:
        old_logits: logits used as the KL reference; cast to float32.
        new_logits: logits of the model being trained; cast to float32.
        input_ids: token ids whose log-probabilities are gathered from the
            last dimension of both logit tensors.
        mask: per-token weights (cast to float32); summed over dim 1 to get
            the number of counted tokens per row.
        beta: coefficient of the KL penalty.
        advantages: one value per row; broadcast over dim 1.

    Returns:
        ``(loss, completion_length, mean_kl)``: the mean over rows of the
        masked per-row loss, the mean number of unmasked tokens per row, and
        the mean over rows of the masked per-row KL term.

    A row whose mask is entirely zero contributes 0 to the loss and to the KL
    metric instead of producing ``0 / 0 = NaN``.
    """
    # All Unsloth Zoo code licensed under LGPLv3
    old_logits = old_logits.to(torch.float32)
    new_logits = new_logits.to(torch.float32)
    input_ids  = input_ids.unsqueeze(-1)

    # log_softmax(x)_i = x_i - logsumexp(x)
    old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
    new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
    old = old_x - torch.logsumexp(old_logits, dim = -1)
    new = new_x - torch.logsumexp(new_logits, dim = -1)

    # Per-token KL penalty (labelled "reverse KL" upstream):
    # exp(d) - d - 1 with d = old - new, which is >= 0 and 0 only when d == 0.
    kl_i = torch.exp(old - new) - (old - new) - 1.0

    # exp(new - new.detach()) is exactly 1 in value, but keeps the gradient of
    # `new`, so the advantage is propagated through the log-probability.
    loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
    loss_i = -(loss_i - beta * kl_i)

    mask = mask.to(torch.float32)
    n_mask_per_reward = mask.sum(1)
    # Guard the division: a fully masked row has a zero numerator as well, so
    # clamping the denominator turns NaN into 0 and changes nothing otherwise.
    safe_n_mask = n_mask_per_reward.clamp(min = 1.0)

    # Normalise per row first, then average rows.
    # See https://github.com/huggingface/trl/pull/2881
    loss_per_reward = (loss_i * mask).sum(1) / safe_n_mask
    loss = loss_per_reward.mean()

    # Metrics only - no autograd tracking needed.
    with torch.inference_mode():
        completion_length = n_mask_per_reward.mean()
        mean_kl_per_reward = (kl_i * mask).sum(1) / safe_n_mask
        mean_kl = mean_kl_per_reward.mean()
    pass
    return loss, completion_length, mean_kl
|
| 83 |
+
|
| 84 |
+
class UnslothEfficientGRPO(torch.autograd.Function):
    """GRPO loss over hidden states, evaluated chunk by chunk along the batch dim.

    `forward` projects each batch chunk through `lm_head`, evaluates
    `grpo_compute_loss` on it and immediately computes the gradient with
    respect to the new hidden states via `torch.func.grad_and_value`.  That
    gradient is saved and returned unchanged by `backward`.
    """
    # All Unsloth Zoo code licensed under LGPLv3
    @staticmethod
    def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
        """Accumulate loss, metrics and hidden-state gradients over batch chunks.

        Args:
            _new_hidden_states: hidden states of the model being trained; the
                gradient is taken with respect to this tensor.
            _old_hidden_states: hidden states used for the reference logits.
            lm_head: weight applied as ``hidden_states @ lm_head.t()``.
            _input_ids, _mask, _advantages, beta: forwarded to
                `grpo_compute_loss` for each chunk.
            scaler: optional object exposing ``get_scale()``; its value
                multiplies the loss that is differentiated.  ``None`` means 1.0.
            n_chunks: requested number of chunks along dim 0.

        Returns:
            Three tensors of shape ``(1,)``: the unscaled loss, the completion
            length and the mean KL, each averaged over the chunks processed.
        """
        def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
            new_logits = torch.matmul(new_hidden_states, lm_head.t())
            new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
            old_logits = torch.matmul(old_hidden_states, lm_head.t())
            old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
            loss, completion_length, mean_kl = grpo_compute_loss(
                old_logits, new_logits, input_ids, mask, beta, advantages,
            )
            # Scale loss if needed for mixed precision training
            scaled_loss = loss * scaling
            # Must add .loss.detach otherwise autograd uses 2x VRAM
            return scaled_loss, (loss.detach(), completion_length, mean_kl,)
        pass

        device = _new_hidden_states.device
        grad_inputs = torch.empty_like(_new_hidden_states)
        accumulated_loss              = torch.zeros(1, device = device)
        accumulated_completion_length = torch.zeros(1, device = device)
        accumulated_mean_kl           = torch.zeros(1, device = device)

        def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
            # Differentiate only w.r.t. argument 0 (the new hidden states); the
            # auxiliary tuple carries the unscaled loss and the two metrics.
            (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
                compute_loss,
                argnums = (0,),
                has_aux = True,
            )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
            accumulated_loss             .add_(unscaled_loss)
            accumulated_completion_length.add_(chunk_completion_length)
            accumulated_mean_kl          .add_(chunk_mean_kl)
            return chunk_grad_input
        pass

        accumulate_chunk = torch.compile(
            accumulate_chunk,
            fullgraph = True,
            options = torch_compile_options,
        )

        # The chunks of `grad_inputs` are written in place below via copy_.
        grad_inputs_chunks = torch.chunk(grad_inputs,        chunks = n_chunks, dim = 0)
        new_hidden_states  = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
        old_hidden_states  = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
        input_ids          = torch.chunk(_input_ids,         chunks = n_chunks, dim = 0)
        mask               = torch.chunk(_mask,              chunks = n_chunks, dim = 0)
        advantages         = torch.chunk(_advantages,        chunks = n_chunks, dim = 0)

        # torch.chunk may return fewer chunks than requested when the batch
        # size is not divisible by `n_chunks`; average over what was actually
        # processed so loss, metrics and gradients are not under-scaled.
        n_actual_chunks = len(grad_inputs_chunks)

        # Get mixed precision scaling if seen
        scaling = scaler.get_scale() if scaler is not None else 1.0

        # Force torch.compile to use dynamic shapes for seqlen dim
        mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)

        for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
            zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):

            mark_dynamic(new_hidden_states_j)
            mark_dynamic(old_hidden_states_j)
            mark_dynamic(input_ids_j)
            mark_dynamic(mask_j)

            grad_inputs_j.copy_(
                accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
            )
        pass

        # Each chunk produced a chunk-level mean, so average across chunks.
        grad_inputs                  .div_(n_actual_chunks)
        accumulated_loss             .div_(n_actual_chunks)
        accumulated_completion_length.div_(n_actual_chunks)
        accumulated_mean_kl          .div_(n_actual_chunks)
        ctx.save_for_backward(grad_inputs)

        return (
            accumulated_loss,
            accumulated_completion_length,
            accumulated_mean_kl,
        )
    pass

    @staticmethod
    def backward(ctx, grad_output, dcompletion_length, dmean_kl):
        """Return the gradient precomputed in `forward` for the new hidden states.

        The incoming `grad_output` is not applied: the saved gradient, which
        already includes the `scaling` factor used in `forward`, is returned
        as-is.  All other forward inputs receive ``None``.
        """
        (grad_input,) = ctx.saved_tensors
        return (grad_input, None, None, None, None, None, None, None, None,)
    pass
|
| 170 |
+
|
| 171 |
+
def grpo_accumulated_loss(
    trainer,
    input_ids,
    logits_to_keep,
    completion_mask,
    advantages,
    n_chunks = -1,
):
    """Run the reference and policy forward passes and return the chunked GRPO loss.

    Args:
        trainer: object providing ``model``, ``accelerator`` (with
            ``unwrap_model`` and ``scaler``) and ``beta``.
        input_ids: 2-D tensor of token ids; dim 0 is the batch size.
        logits_to_keep: number of trailing positions treated as the
            completion; ``logits_to_keep + 1`` is passed to the model.
        completion_mask: forwarded to `UnslothEfficientGRPO` as the mask.
        advantages: forwarded to `UnslothEfficientGRPO`.
        n_chunks: requested number of batch chunks; ``-1`` means one chunk
            per row.  It is snapped to the smallest divisor of the batch size
            that is >= the request (capped at the batch size).

    Returns:
        ``(loss, completion_length, mean_kl)`` from `UnslothEfficientGRPO`.
    """
    # All Unsloth Zoo code licensed under LGPLv3
    bsz, _ = input_ids.shape
    # Find closest multiple
    factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
    if n_chunks == -1: n_chunks = bsz
    n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]

    # float16 unless ACCELERATE_MIXED_PRECISION is set to something other than 'fp16'.
    mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
    # NOTE(review): this flag is set here and not restored by this function.
    # Presumably the patched model forward reads it and returns hidden states
    # in `.logits` - confirm against the model patching code.
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

    completion_input_ids = input_ids[:, -logits_to_keep:]
    lm_head = trainer.model.get_output_embeddings().weight

    with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
        # Reference pass: no autograd, run inside the model's disable_adapter() context.
        with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
            old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
        pass

        # Policy pass: same inputs, gradients enabled.
        new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits

        loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
            new_hidden_states, old_hidden_states, lm_head,
            completion_input_ids, completion_mask, advantages, trainer.beta,
            trainer.accelerator.scaler,
            n_chunks,
        )
        return loss, completion_length, mean_kl
    pass
pass
|
| 217 |
+
|
| 218 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
    """Compiled variant of the masked GRPO loss computed from full logits.

    Args:
        old_logits: logits used as the KL reference; cast to float32.
        new_logits: logits of the model being trained; cast to float32.
        input_ids: token ids whose log-probabilities are gathered from the
            last dimension of both logit tensors.
        mask: per-token weights (cast to float32); summed over dim 1 to get
            the number of counted tokens per row.
        beta: coefficient of the KL penalty.
        advantages: one value per row; broadcast over dim 1.

    Returns:
        ``(loss, completion_length, mean_kl)``: the mean over rows of the
        masked per-row loss, the mean number of unmasked tokens per row, and
        the mean over rows of the masked per-row KL term.

    A row whose mask is entirely zero contributes 0 to the loss and to the KL
    metric instead of producing ``0 / 0 = NaN``.
    """
    # All Unsloth Zoo code licensed under LGPLv3
    old_logits = old_logits.to(torch.float32)
    new_logits = new_logits.to(torch.float32)
    input_ids  = input_ids.unsqueeze(-1)

    # log_softmax(x)_i = x_i - logsumexp(x)
    old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
    new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
    old = old_x - torch.logsumexp(old_logits, dim = -1)
    new = new_x - torch.logsumexp(new_logits, dim = -1)

    # Per-token KL penalty (labelled "reverse KL" upstream):
    # exp(d) - d - 1 with d = old - new, which is >= 0 and 0 only when d == 0.
    kl_i = torch.exp(old - new) - (old - new) - 1.0

    # exp(new - new.detach()) is exactly 1 in value, but keeps the gradient of
    # `new`, so the advantage is propagated through the log-probability.
    loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
    loss_i = -(loss_i - beta * kl_i)

    mask = mask.to(torch.float32)
    n_mask_per_reward = mask.sum(1)
    # Guard the division: a fully masked row has a zero numerator as well, so
    # clamping the denominator turns NaN into 0 and changes nothing otherwise.
    safe_n_mask = n_mask_per_reward.clamp(min = 1.0)

    # Normalise per row first, then average rows.
    # See https://github.com/huggingface/trl/pull/2881
    loss_per_reward = (loss_i * mask).sum(1) / safe_n_mask
    loss = loss_per_reward.mean()

    # Metrics only - no autograd tracking needed.
    with torch.inference_mode():
        completion_length = n_mask_per_reward.mean()
        mean_kl_per_reward = (kl_i * mask).sum(1) / safe_n_mask
        mean_kl = mean_kl_per_reward.mean()
    pass
    return loss, completion_length, mean_kl
|
| 259 |
+
|
| 260 |
+
def vLLMSamplingParams(**kwargs):
    """Create a vLLM `SamplingParams` and record the kwargs it was built from.

    Args:
        **kwargs: passed unchanged to ``vllm.SamplingParams``.

    Returns:
        The new `SamplingParams` instance, with the original keyword
        arguments stored on it as the ``_set_kwargs`` attribute.
    """
    # Imported inside the function, so vLLM is only needed when this is called.
    from vllm import SamplingParams

    params = SamplingParams(**kwargs)
    setattr(params, "_set_kwargs", kwargs)
    return params
|
| 265 |
+
@dataclass
|
| 266 |
+
class UnslothGRPOConfig(GRPOConfig):
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
Configuration class for the [`GRPOTrainer`].
|
| 270 |
+
|
| 271 |
+
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
|
| 272 |
+
[`~transformers.TrainingArguments`] documentation.
|
| 273 |
+
|
| 274 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 275 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 276 |
+
command line.
|
| 277 |
+
|
| 278 |
+
Parameters:
|
| 279 |
+
> Parameters that control the model and reference model
|
| 280 |
+
|
| 281 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 282 |
+
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
| 283 |
+
argument of the [`GRPOTrainer`] is provided as a string.
|
| 284 |
+
|
| 285 |
+
> Parameters that control the data preprocessing
|
| 286 |
+
|
| 287 |
+
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
| 288 |
+
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
|
| 289 |
+
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
|
| 290 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 291 |
+
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
|
| 292 |
+
num_generations (`int` or `None`, *optional*, defaults to `8`):
|
| 293 |
+
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
|
| 294 |
+
must be divisible by this value.
|
| 295 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
| 296 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 297 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
|
| 298 |
+
Maximum length of the generated completion.
|
| 299 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
| 300 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
| 301 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
| 302 |
+
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
|
| 303 |
+
with vLLM generation.
|
| 304 |
+
|
| 305 |
+
> Parameters that control generation acceleration powered by vLLM
|
| 306 |
+
|
| 307 |
+
use_vllm (`bool`, *optional*, defaults to `False`):
|
| 308 |
+
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
| 309 |
+
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
| 310 |
+
vllm_device (`str`, *optional*, defaults to `"auto"`):
|
| 311 |
+
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
|
| 312 |
+
automatically select the next available GPU after the last one used for training. This assumes that
|
| 313 |
+
training has not already occupied all available GPUs. If only one device is available, the device will be
|
| 314 |
+
shared between both training and vLLM.
|
| 315 |
+
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
| 316 |
+
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
|
| 317 |
+
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
|
| 318 |
+
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
|
| 319 |
+
during initialization.
|
| 320 |
+
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
|
| 321 |
+
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
|
| 322 |
+
based on the model configuration. Find the supported values in the vLLM documentation.
|
| 323 |
+
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
| 324 |
+
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
|
| 325 |
+
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
|
| 326 |
+
context size, which might be much larger than the KV cache, leading to inefficiencies.
|
| 327 |
+
|
| 328 |
+
> Parameters that control the training
|
| 329 |
+
|
| 330 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
| 331 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 332 |
+
[`~transformers.TrainingArguments`].
|
| 333 |
+
beta (`float`, *optional*, defaults to `0.04`):
|
| 334 |
+
KL coefficient.
|
| 335 |
+
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
|
| 336 |
+
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
|
| 337 |
+
weighted equally with weight `1.0`.
|
| 338 |
+
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
| 339 |
+
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
| 340 |
+
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
|
| 341 |
+
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
| 342 |
+
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
|
| 343 |
+
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
| 344 |
+
between the current policy and the previous reference policy during updates. The reference policy is
|
| 345 |
+
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
| 346 |
+
must set `sync_ref_model=True`.
|
| 347 |
+
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
|
| 348 |
+
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
| 349 |
+
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
|
| 350 |
+
set `sync_ref_model=True`.
|
| 351 |
+
|
| 352 |
+
> Parameters that control the logging
|
| 353 |
+
|
| 354 |
+
log_completions (`bool`, *optional*, defaults to `False`):
|
| 355 |
+
Whether to log the completions during training.
|
| 356 |
+
|
| 357 |
+
"""
|
| 358 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 359 |
+
default = None,
|
| 360 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 361 |
+
)
|
| 362 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 363 |
+
default = -1,
|
| 364 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 365 |
+
)
|
| 366 |
+
def __init__(
|
| 367 |
+
self,
|
| 368 |
+
output_dir = None,
|
| 369 |
+
overwrite_output_dir = None,
|
| 370 |
+
do_train = False,
|
| 371 |
+
do_eval = False,
|
| 372 |
+
do_predict = False,
|
| 373 |
+
eval_strategy = 'no',
|
| 374 |
+
prediction_loss_only = False,
|
| 375 |
+
per_device_train_batch_size = 4,
|
| 376 |
+
per_device_eval_batch_size = 4,
|
| 377 |
+
per_gpu_train_batch_size = None,
|
| 378 |
+
per_gpu_eval_batch_size = None,
|
| 379 |
+
gradient_accumulation_steps = 2,
|
| 380 |
+
eval_accumulation_steps = 2,
|
| 381 |
+
eval_delay = 0,
|
| 382 |
+
torch_empty_cache_steps = 250,
|
| 383 |
+
learning_rate = 5e-05,
|
| 384 |
+
weight_decay = 0.01,
|
| 385 |
+
adam_beta1 = 0.9,
|
| 386 |
+
adam_beta2 = 0.999,
|
| 387 |
+
adam_epsilon = 1e-08,
|
| 388 |
+
max_grad_norm = 1.0,
|
| 389 |
+
num_train_epochs = 3.0,
|
| 390 |
+
max_steps = -1,
|
| 391 |
+
lr_scheduler_type = 'linear',
|
| 392 |
+
warmup_ratio = 0.1,
|
| 393 |
+
warmup_steps = 0,
|
| 394 |
+
log_level = 'passive',
|
| 395 |
+
log_level_replica = 'warning',
|
| 396 |
+
log_on_each_node = True,
|
| 397 |
+
logging_dir = None,
|
| 398 |
+
logging_strategy = 'steps',
|
| 399 |
+
logging_first_step = False,
|
| 400 |
+
logging_steps = 1,
|
| 401 |
+
logging_nan_inf_filter = False,
|
| 402 |
+
save_strategy = 'steps',
|
| 403 |
+
save_steps = 500,
|
| 404 |
+
save_total_limit = None,
|
| 405 |
+
save_safetensors = True,
|
| 406 |
+
save_on_each_node = False,
|
| 407 |
+
save_only_model = False,
|
| 408 |
+
restore_callback_states_from_checkpoint = False,
|
| 409 |
+
no_cuda = False,
|
| 410 |
+
use_cpu = False,
|
| 411 |
+
use_mps_device = False,
|
| 412 |
+
seed = 3407,
|
| 413 |
+
data_seed = 3407,
|
| 414 |
+
jit_mode_eval = False,
|
| 415 |
+
use_ipex = False,
|
| 416 |
+
bf16 = False,
|
| 417 |
+
fp16 = False,
|
| 418 |
+
fp16_opt_level = 'O1',
|
| 419 |
+
half_precision_backend = 'auto',
|
| 420 |
+
bf16_full_eval = False,
|
| 421 |
+
fp16_full_eval = False,
|
| 422 |
+
tf32 = None,
|
| 423 |
+
local_rank = -1,
|
| 424 |
+
ddp_backend = None,
|
| 425 |
+
tpu_num_cores = None,
|
| 426 |
+
tpu_metrics_debug = False,
|
| 427 |
+
debug = '',
|
| 428 |
+
dataloader_drop_last = False,
|
| 429 |
+
eval_steps = None,
|
| 430 |
+
dataloader_num_workers = 0,
|
| 431 |
+
dataloader_prefetch_factor = None,
|
| 432 |
+
past_index = -1,
|
| 433 |
+
run_name = None,
|
| 434 |
+
disable_tqdm = None,
|
| 435 |
+
remove_unused_columns = False,
|
| 436 |
+
label_names = None,
|
| 437 |
+
load_best_model_at_end = False,
|
| 438 |
+
metric_for_best_model = None,
|
| 439 |
+
greater_is_better = None,
|
| 440 |
+
ignore_data_skip = False,
|
| 441 |
+
fsdp = '',
|
| 442 |
+
fsdp_min_num_params = 0,
|
| 443 |
+
fsdp_config = None,
|
| 444 |
+
tp_size = 0,
|
| 445 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
| 446 |
+
accelerator_config = None,
|
| 447 |
+
deepspeed = None,
|
| 448 |
+
label_smoothing_factor = 0.0,
|
| 449 |
+
optim = 'adamw_8bit',
|
| 450 |
+
optim_args = None,
|
| 451 |
+
adafactor = False,
|
| 452 |
+
group_by_length = False,
|
| 453 |
+
length_column_name = 'length',
|
| 454 |
+
report_to = None,
|
| 455 |
+
ddp_find_unused_parameters = None,
|
| 456 |
+
ddp_bucket_cap_mb = None,
|
| 457 |
+
ddp_broadcast_buffers = None,
|
| 458 |
+
dataloader_pin_memory = True,
|
| 459 |
+
dataloader_persistent_workers = False,
|
| 460 |
+
skip_memory_metrics = True,
|
| 461 |
+
use_legacy_prediction_loop = False,
|
| 462 |
+
push_to_hub = False,
|
| 463 |
+
resume_from_checkpoint = None,
|
| 464 |
+
hub_model_id = None,
|
| 465 |
+
hub_strategy = 'every_save',
|
| 466 |
+
hub_token = None,
|
| 467 |
+
hub_private_repo = None,
|
| 468 |
+
hub_always_push = False,
|
| 469 |
+
gradient_checkpointing = False,
|
| 470 |
+
gradient_checkpointing_kwargs = None,
|
| 471 |
+
include_inputs_for_metrics = False,
|
| 472 |
+
eval_do_concat_batches = True,
|
| 473 |
+
fp16_backend = 'auto',
|
| 474 |
+
evaluation_strategy = None,
|
| 475 |
+
push_to_hub_model_id = None,
|
| 476 |
+
push_to_hub_organization = None,
|
| 477 |
+
push_to_hub_token = None,
|
| 478 |
+
mp_parameters = '',
|
| 479 |
+
auto_find_batch_size = False,
|
| 480 |
+
full_determinism = False,
|
| 481 |
+
torchdynamo = None,
|
| 482 |
+
ray_scope = 'last',
|
| 483 |
+
ddp_timeout = 1800,
|
| 484 |
+
torch_compile = False,
|
| 485 |
+
torch_compile_backend = None,
|
| 486 |
+
torch_compile_mode = None,
|
| 487 |
+
dispatch_batches = None,
|
| 488 |
+
split_batches = None,
|
| 489 |
+
include_tokens_per_second = False,
|
| 490 |
+
include_num_input_tokens_seen = False,
|
| 491 |
+
neftune_noise_alpha = None,
|
| 492 |
+
optim_target_modules = None,
|
| 493 |
+
batch_eval_metrics = False,
|
| 494 |
+
eval_on_start = False,
|
| 495 |
+
use_liger_kernel = False,
|
| 496 |
+
eval_use_gather_object = False,
|
| 497 |
+
average_tokens_across_devices = False,
|
| 498 |
+
model_init_kwargs = None,
|
| 499 |
+
max_prompt_length = 512,
|
| 500 |
+
num_generations = 8,
|
| 501 |
+
temperature = 0.9,
|
| 502 |
+
max_completion_length = 256,
|
| 503 |
+
ds3_gather_for_generation = True,
|
| 504 |
+
use_vllm = False,
|
| 505 |
+
vllm_device = 'auto',
|
| 506 |
+
vllm_gpu_memory_utilization = 0.9,
|
| 507 |
+
vllm_dtype = 'auto',
|
| 508 |
+
vllm_max_model_len = None,
|
| 509 |
+
beta = 0.04,
|
| 510 |
+
reward_weights = None,
|
| 511 |
+
sync_ref_model = False,
|
| 512 |
+
ref_model_mixup_alpha = 0.9,
|
| 513 |
+
ref_model_sync_steps = 64,
|
| 514 |
+
log_completions = False,
|
| 515 |
+
vllm_sampling_params = None,
|
| 516 |
+
unsloth_num_chunks = -1,
|
| 517 |
+
**kwargs,
|
| 518 |
+
):
|
| 519 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 520 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 521 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 522 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 523 |
+
save_strategy = 'no'
|
| 524 |
+
div = per_device_train_batch_size // num_generations
|
| 525 |
+
if div * num_generations != per_device_train_batch_size:
|
| 526 |
+
print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
|
| 527 |
+
per_device_train_batch_size = num_generations
|
| 528 |
+
|
| 529 |
+
super().__init__(
|
| 530 |
+
output_dir = output_dir,
|
| 531 |
+
overwrite_output_dir = overwrite_output_dir,
|
| 532 |
+
do_train = do_train,
|
| 533 |
+
do_eval = do_eval,
|
| 534 |
+
do_predict = do_predict,
|
| 535 |
+
eval_strategy = eval_strategy,
|
| 536 |
+
prediction_loss_only = prediction_loss_only,
|
| 537 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 538 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 539 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 540 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 541 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 542 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 543 |
+
eval_delay = eval_delay,
|
| 544 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 545 |
+
learning_rate = learning_rate,
|
| 546 |
+
weight_decay = weight_decay,
|
| 547 |
+
adam_beta1 = adam_beta1,
|
| 548 |
+
adam_beta2 = adam_beta2,
|
| 549 |
+
adam_epsilon = adam_epsilon,
|
| 550 |
+
max_grad_norm = max_grad_norm,
|
| 551 |
+
num_train_epochs = num_train_epochs,
|
| 552 |
+
max_steps = max_steps,
|
| 553 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 554 |
+
warmup_ratio = warmup_ratio,
|
| 555 |
+
warmup_steps = warmup_steps,
|
| 556 |
+
log_level = log_level,
|
| 557 |
+
log_level_replica = log_level_replica,
|
| 558 |
+
log_on_each_node = log_on_each_node,
|
| 559 |
+
logging_dir = logging_dir,
|
| 560 |
+
logging_strategy = logging_strategy,
|
| 561 |
+
logging_first_step = logging_first_step,
|
| 562 |
+
logging_steps = logging_steps,
|
| 563 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 564 |
+
save_strategy = save_strategy,
|
| 565 |
+
save_steps = save_steps,
|
| 566 |
+
save_total_limit = save_total_limit,
|
| 567 |
+
save_safetensors = save_safetensors,
|
| 568 |
+
save_on_each_node = save_on_each_node,
|
| 569 |
+
save_only_model = save_only_model,
|
| 570 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 571 |
+
no_cuda = no_cuda,
|
| 572 |
+
use_cpu = use_cpu,
|
| 573 |
+
use_mps_device = use_mps_device,
|
| 574 |
+
seed = seed,
|
| 575 |
+
data_seed = data_seed,
|
| 576 |
+
jit_mode_eval = jit_mode_eval,
|
| 577 |
+
use_ipex = use_ipex,
|
| 578 |
+
bf16 = bf16,
|
| 579 |
+
fp16 = fp16,
|
| 580 |
+
fp16_opt_level = fp16_opt_level,
|
| 581 |
+
half_precision_backend = half_precision_backend,
|
| 582 |
+
bf16_full_eval = bf16_full_eval,
|
| 583 |
+
fp16_full_eval = fp16_full_eval,
|
| 584 |
+
tf32 = tf32,
|
| 585 |
+
local_rank = local_rank,
|
| 586 |
+
ddp_backend = ddp_backend,
|
| 587 |
+
tpu_num_cores = tpu_num_cores,
|
| 588 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
| 589 |
+
debug = debug,
|
| 590 |
+
dataloader_drop_last = dataloader_drop_last,
|
| 591 |
+
eval_steps = eval_steps,
|
| 592 |
+
dataloader_num_workers = dataloader_num_workers,
|
| 593 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 594 |
+
past_index = past_index,
|
| 595 |
+
run_name = run_name,
|
| 596 |
+
disable_tqdm = disable_tqdm,
|
| 597 |
+
remove_unused_columns = remove_unused_columns,
|
| 598 |
+
label_names = label_names,
|
| 599 |
+
load_best_model_at_end = load_best_model_at_end,
|
| 600 |
+
metric_for_best_model = metric_for_best_model,
|
| 601 |
+
greater_is_better = greater_is_better,
|
| 602 |
+
ignore_data_skip = ignore_data_skip,
|
| 603 |
+
fsdp = fsdp,
|
| 604 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
| 605 |
+
fsdp_config = fsdp_config,
|
| 606 |
+
tp_size = tp_size,
|
| 607 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 608 |
+
accelerator_config = accelerator_config,
|
| 609 |
+
deepspeed = deepspeed,
|
| 610 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 611 |
+
optim = optim,
|
| 612 |
+
optim_args = optim_args,
|
| 613 |
+
adafactor = adafactor,
|
| 614 |
+
group_by_length = group_by_length,
|
| 615 |
+
length_column_name = length_column_name,
|
| 616 |
+
report_to = report_to,
|
| 617 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 618 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 619 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 620 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
| 621 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 622 |
+
skip_memory_metrics = skip_memory_metrics,
|
| 623 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 624 |
+
push_to_hub = push_to_hub,
|
| 625 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
| 626 |
+
hub_model_id = hub_model_id,
|
| 627 |
+
hub_strategy = hub_strategy,
|
| 628 |
+
hub_token = hub_token,
|
| 629 |
+
hub_private_repo = hub_private_repo,
|
| 630 |
+
hub_always_push = hub_always_push,
|
| 631 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 632 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 633 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 634 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 635 |
+
fp16_backend = fp16_backend,
|
| 636 |
+
evaluation_strategy = evaluation_strategy,
|
| 637 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
| 638 |
+
push_to_hub_organization = push_to_hub_organization,
|
| 639 |
+
push_to_hub_token = push_to_hub_token,
|
| 640 |
+
mp_parameters = mp_parameters,
|
| 641 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 642 |
+
full_determinism = full_determinism,
|
| 643 |
+
torchdynamo = torchdynamo,
|
| 644 |
+
ray_scope = ray_scope,
|
| 645 |
+
ddp_timeout = ddp_timeout,
|
| 646 |
+
torch_compile = torch_compile,
|
| 647 |
+
torch_compile_backend = torch_compile_backend,
|
| 648 |
+
torch_compile_mode = torch_compile_mode,
|
| 649 |
+
dispatch_batches = dispatch_batches,
|
| 650 |
+
split_batches = split_batches,
|
| 651 |
+
include_tokens_per_second = include_tokens_per_second,
|
| 652 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 653 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 654 |
+
optim_target_modules = optim_target_modules,
|
| 655 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 656 |
+
eval_on_start = eval_on_start,
|
| 657 |
+
use_liger_kernel = use_liger_kernel,
|
| 658 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 659 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 660 |
+
model_init_kwargs = model_init_kwargs,
|
| 661 |
+
max_prompt_length = max_prompt_length,
|
| 662 |
+
num_generations = num_generations,
|
| 663 |
+
temperature = temperature,
|
| 664 |
+
max_completion_length = max_completion_length,
|
| 665 |
+
ds3_gather_for_generation = ds3_gather_for_generation,
|
| 666 |
+
use_vllm = use_vllm,
|
| 667 |
+
vllm_device = vllm_device,
|
| 668 |
+
vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
|
| 669 |
+
vllm_dtype = vllm_dtype,
|
| 670 |
+
vllm_max_model_len = vllm_max_model_len,
|
| 671 |
+
beta = beta,
|
| 672 |
+
reward_weights = reward_weights,
|
| 673 |
+
sync_ref_model = sync_ref_model,
|
| 674 |
+
ref_model_mixup_alpha = ref_model_mixup_alpha,
|
| 675 |
+
ref_model_sync_steps = ref_model_sync_steps,
|
| 676 |
+
log_completions = log_completions,**kwargs)
|
| 677 |
+
self.vllm_sampling_params = vllm_sampling_params
|
| 678 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
| 679 |
+
pass
|
| 680 |
+
|
| 681 |
+
class _UnslothGRPOTrainer(Trainer):
|
| 682 |
+
""""""
|
| 683 |
+
|
| 684 |
+
_tag_names = ["trl", "grpo"]
|
| 685 |
+
|
| 686 |
+
def __init__(
|
| 687 |
+
self,
|
| 688 |
+
model: Union[str, PreTrainedModel],
|
| 689 |
+
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
| 690 |
+
args: GRPOConfig = None,
|
| 691 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 692 |
+
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
| 693 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 694 |
+
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
| 695 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 696 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 697 |
+
peft_config: Optional["PeftConfig"] = None,
|
| 698 |
+
):
|
| 699 |
+
|
| 700 |
+
if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
|
| 701 |
+
# Args
|
| 702 |
+
if args is None:
|
| 703 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 704 |
+
model_name = model_name.split("/")[-1]
|
| 705 |
+
args = GRPOConfig(f"{model_name}-GRPO")
|
| 706 |
+
|
| 707 |
+
# Models
|
| 708 |
+
# Trained model
|
| 709 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 710 |
+
if isinstance(model, str):
|
| 711 |
+
model_id = model
|
| 712 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 713 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
| 714 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 715 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 716 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 717 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 718 |
+
else:
|
| 719 |
+
raise ValueError(
|
| 720 |
+
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
| 721 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 722 |
+
)
|
| 723 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 724 |
+
model_init_kwargs["use_cache"] = (
|
| 725 |
+
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
| 726 |
+
)
|
| 727 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 728 |
+
else:
|
| 729 |
+
model_id = model.config._name_or_path
|
| 730 |
+
if args.model_init_kwargs is not None:
|
| 731 |
+
raise ValueError(
|
| 732 |
+
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
| 733 |
+
"This argument can only be used when the `model` argument is a string."
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
if False:
|
| 737 |
+
model = model
|
| 738 |
+
|
| 739 |
+
# Reference model
|
| 740 |
+
if is_deepspeed_zero3_enabled():
|
| 741 |
+
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
|
| 742 |
+
elif not is_peft_model(model):
|
| 743 |
+
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
| 744 |
+
self.ref_model = create_reference_model(model)
|
| 745 |
+
else:
|
| 746 |
+
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
| 747 |
+
# to revert to the initial model.
|
| 748 |
+
self.ref_model = None
|
| 749 |
+
|
| 750 |
+
# Processing class
|
| 751 |
+
if processing_class is None:
|
| 752 |
+
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
| 753 |
+
|
| 754 |
+
# Reward functions
|
| 755 |
+
if not isinstance(reward_funcs, list):
|
| 756 |
+
reward_funcs = [reward_funcs]
|
| 757 |
+
for i, reward_func in enumerate(reward_funcs):
|
| 758 |
+
if isinstance(reward_func, str):
|
| 759 |
+
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
| 760 |
+
reward_func, num_labels=1, **model_init_kwargs
|
| 761 |
+
)
|
| 762 |
+
self.reward_funcs = reward_funcs
|
| 763 |
+
|
| 764 |
+
# Reward weights
|
| 765 |
+
if args.reward_weights is not None:
|
| 766 |
+
if len(args.reward_weights) != len(reward_funcs):
|
| 767 |
+
raise ValueError(
|
| 768 |
+
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
| 769 |
+
f"functions ({len(reward_funcs)})"
|
| 770 |
+
)
|
| 771 |
+
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
|
| 772 |
+
else:
|
| 773 |
+
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
|
| 774 |
+
|
| 775 |
+
# Reward processing class
|
| 776 |
+
if reward_processing_classes is None:
|
| 777 |
+
reward_processing_classes = [None] * len(reward_funcs)
|
| 778 |
+
elif not isinstance(reward_processing_classes, list):
|
| 779 |
+
reward_processing_classes = [reward_processing_classes]
|
| 780 |
+
else:
|
| 781 |
+
if len(reward_processing_classes) != len(reward_funcs):
|
| 782 |
+
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
| 783 |
+
|
| 784 |
+
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
| 785 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 786 |
+
if reward_processing_class is None:
|
| 787 |
+
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
| 788 |
+
if reward_processing_class.pad_token_id is None:
|
| 789 |
+
reward_processing_class.pad_token = reward_processing_class.eos_token
|
| 790 |
+
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
| 791 |
+
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
| 792 |
+
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
| 793 |
+
reward_processing_classes[i] = reward_processing_class
|
| 794 |
+
self.reward_processing_classes = reward_processing_classes
|
| 795 |
+
|
| 796 |
+
# Data collator
|
| 797 |
+
def data_collator(features): # No data collation is needed in GRPO
|
| 798 |
+
return features
|
| 799 |
+
|
| 800 |
+
# Training arguments
|
| 801 |
+
self.max_prompt_length = args.max_prompt_length
|
| 802 |
+
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
| 803 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
| 804 |
+
self.use_vllm = args.use_vllm
|
| 805 |
+
|
| 806 |
+
self.beta = args.beta
|
| 807 |
+
|
| 808 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 809 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
| 810 |
+
# "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
|
| 811 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
| 812 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
| 813 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
| 814 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 815 |
+
|
| 816 |
+
# Initialize the metrics
|
| 817 |
+
self._metrics = defaultdict(list)
|
| 818 |
+
self.log_completions = args.log_completions
|
| 819 |
+
|
| 820 |
+
super().__init__(
|
| 821 |
+
model=model,
|
| 822 |
+
args=args,
|
| 823 |
+
data_collator=data_collator,
|
| 824 |
+
train_dataset=train_dataset,
|
| 825 |
+
eval_dataset=eval_dataset,
|
| 826 |
+
processing_class=processing_class,
|
| 827 |
+
callbacks=callbacks,
|
| 828 |
+
optimizers=optimizers,
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
| 832 |
+
num_processes = self.accelerator.num_processes
|
| 833 |
+
global_batch_size = args.per_device_train_batch_size * num_processes
|
| 834 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 835 |
+
if self.num_generations not in possible_values:
|
| 836 |
+
raise ValueError(
|
| 837 |
+
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
| 838 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
| 839 |
+
f"batch size, the valid values for the number of generations are: {possible_values}."
|
| 840 |
+
)
|
| 841 |
+
if self.args.eval_strategy != "no":
|
| 842 |
+
global_batch_size = args.per_device_eval_batch_size * num_processes
|
| 843 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 844 |
+
if self.num_generations not in possible_values:
|
| 845 |
+
raise ValueError(
|
| 846 |
+
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
| 847 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
| 848 |
+
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
| 852 |
+
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
| 853 |
+
# it's safer to set it in all cases.
|
| 854 |
+
set_seed(args.seed, device_specific=True)
|
| 855 |
+
|
| 856 |
+
if self.use_vllm:
|
| 857 |
+
self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
|
| 858 |
+
temperature=args.temperature,
|
| 859 |
+
max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
|
| 860 |
+
else:
|
| 861 |
+
self.generation_config = GenerationConfig(
|
| 862 |
+
max_new_tokens=self.max_completion_length,
|
| 863 |
+
do_sample=True,
|
| 864 |
+
temperature=args.temperature,
|
| 865 |
+
pad_token_id=processing_class.pad_token_id,
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 869 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 870 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 871 |
+
self.model_accepts_loss_kwargs = False
|
| 872 |
+
|
| 873 |
+
# Add tags to the model
|
| 874 |
+
self.model.add_model_tags(self._tag_names)
|
| 875 |
+
|
| 876 |
+
if self.ref_model is not None:
|
| 877 |
+
if self.is_deepspeed_enabled:
|
| 878 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 879 |
+
else:
|
| 880 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 881 |
+
|
| 882 |
+
if args.sync_ref_model:
|
| 883 |
+
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
|
| 884 |
+
|
| 885 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 886 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 887 |
+
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
| 888 |
+
|
| 889 |
+
def _set_signature_columns_if_needed(self):
|
| 890 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 891 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
| 892 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
| 893 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
| 894 |
+
if self._signature_columns is None:
|
| 895 |
+
self._signature_columns = ["prompt"]
|
| 896 |
+
|
| 897 |
+
def _get_train_sampler(self) -> Sampler:
|
| 898 |
+
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
| 899 |
+
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
| 900 |
+
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
| 901 |
+
# preventing discrepancies in group formation.
|
| 902 |
+
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
|
| 903 |
+
|
| 904 |
+
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
| 905 |
+
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
| 906 |
+
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
| 907 |
+
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
| 908 |
+
# preventing discrepancies in group formation.
|
| 909 |
+
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
|
| 910 |
+
|
| 911 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
| 912 |
+
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
| 913 |
+
if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
|
| 914 |
+
return None # Unsloth efficient GRPO
|
| 915 |
+
# Otherwise, calculate normally:
|
| 916 |
+
if not hasattr(self, '_autocast_dtype'):
|
| 917 |
+
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
|
| 918 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
|
| 919 |
+
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
|
| 920 |
+
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
| 921 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
| 922 |
+
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
| 923 |
+
|
| 924 |
+
input_ids = input_ids[:, -logits_to_keep:]
|
| 925 |
+
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
| 926 |
+
# See https://github.com/huggingface/trl/issues/2770
|
| 927 |
+
logits = logits[:, -logits_to_keep:]
|
| 928 |
+
return logits
|
| 929 |
+
# return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
|
| 930 |
+
pass
|
| 931 |
+
|
| 932 |
+
def _move_model_to_vllm(self, *args, **kwargs): return None
|
| 933 |
+
|
| 934 |
+
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
    """Turn a batch of prompt examples into the tensors consumed by ``compute_loss``.

    Steps, in order: tokenize the prompts (left-padded), generate completions
    (through vLLM when ``self.args.use_vllm`` is set, otherwise with
    ``generate``), mask everything after the first EOS token, query the
    reference model through ``_get_per_token_logps``, score the completions
    with every reward function, and normalise the weighted reward sum within
    each group of ``self.num_generations`` samples to obtain the advantages.

    Args:
        inputs: Iterated as a sequence of example mappings, each holding a
            ``"prompt"`` entry. Every other key except ``"completion"`` is
            forwarded to the custom (non-module) reward functions.

    Returns:
        A dict with the keys ``prompt_ids``, ``prompt_mask``,
        ``completion_ids``, ``completion_mask``, ``ref_per_token_logps`` and
        ``advantages``.
    """

    def _inference_autocast():
        # Mixed-precision context for the no-grad forward passes below.
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') != '0':
            return torch.amp.autocast(device_type = 'cuda', dtype = torch.float16)
        if torch.is_autocast_enabled('cuda'):
            # An outer autocast region is already active: reuse it instead of
            # handing a non-dtype object to `torch.amp.autocast`.
            return nullcontext()
        use_fp16 = os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16'
        return torch.amp.autocast(
            device_type = 'cuda',
            dtype = torch.float16 if use_fp16 else torch.bfloat16,
        )

    device = self.accelerator.device
    prompts = [x["prompt"] for x in inputs]
    prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
    prompt_inputs = self.processing_class(
        prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    # Prompts are left-padded, so truncation keeps the right-most tokens.
    if self.max_prompt_length is not None:
        prompt_ids = prompt_ids[:, -self.max_prompt_length :]
        prompt_mask = prompt_mask[:, -self.max_prompt_length :]

    # Generate completions using either vLLM or regular generation
    if self.args.use_vllm:
        # First, have main process load weights if needed
        if self.state.global_step != self._last_loaded_step:
            self._move_model_to_vllm()
            self._last_loaded_step = self.state.global_step

        # Gather all prompts and generate them in a single call in the main process
        all_prompts_text = gather_object(prompts_text)
        if self.accelerator.is_main_process:
            outputs = self.llm.generate(
                all_prompts_text,
                sampling_params=self.sampling_params,
                use_tqdm=False,
                lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True),
            )
            completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
        else:
            completion_ids = [None] * len(all_prompts_text)

        # Broadcast the completions from the main process to all processes, then
        # keep only the slice that belongs to this process.
        completion_ids = broadcast_object_list(completion_ids, from_process=0)
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        completion_ids = completion_ids[process_slice]

        # Pad the completions, and concatenate them with the prompts
        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
        completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    else:
        # Regular generation path
        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
            prompt_completion_ids = unwrapped_model.generate(
                prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
            )

        # Compute prompt length and extract completion ids
        prompt_length = prompt_ids.size(1)
        prompt_ids = prompt_completion_ids[:, :prompt_length]
        completion_ids = prompt_completion_ids[:, prompt_length:]

    # Mask everything after the first EOS token. Rows without any EOS keep the
    # sentinel index `is_eos.size(1)`, i.e. the whole row stays unmasked.
    is_eos = completion_ids == self.processing_class.eos_token_id
    has_eos = is_eos.any(dim=1)
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
    eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
    sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

    # Concatenate prompt_mask with completion_mask for logit computation
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

    logits_to_keep = completion_ids.size(1)  # only the completion tokens need logits

    with torch.inference_mode(), _inference_autocast():
        if self.ref_model is not None:
            ref_per_token_logps = self._get_per_token_logps(
                self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
            )
        else:
            # No separate reference model: run the policy with its adapter disabled.
            with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
                ref_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
                )

    # Decode the generated completions
    completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational(inputs[0]):
        completions = []
        for prompt, completion in zip(prompts, completions_text):
            # A trailing assistant message is removed from the prompt (in place)
            # and prepended to the generated text.
            bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
            completions.append([{"role": "assistant", "content": bootstrap + completion}])
    else:
        completions = completions_text

    rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
    for i, (reward_func, reward_processing_class) in enumerate(
        zip(self.reward_funcs, self.reward_processing_classes)
    ):
        if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
            if is_conversational(inputs[0]):
                messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
            else:
                texts = [p + c for p, c in zip(prompts, completions)]
            reward_inputs = reward_processing_class(
                texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
            )
            reward_inputs = super()._prepare_inputs(reward_inputs)
            with torch.inference_mode(), _inference_autocast():
                rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
        else:
            # Forward every input column except "prompt" and "completion" as keyword lists
            keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
            reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
            output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
            rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

    # Gather the reward per function: this part is crucial, because the rewards are
    # normalized per group and the completions may be distributed across processes
    rewards_per_func = gather(rewards_per_func)

    # Apply weights to each reward function's output and sum
    rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

    # Group-wise statistics, expanded back to one value per sample
    grouped_rewards = rewards.view(-1, self.num_generations)
    mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
    std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0)

    # Normalize the rewards to compute the advantages
    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

    # Slice to keep only the local part of the data
    process_slice = slice(
        self.accelerator.process_index * len(prompts),
        (self.accelerator.process_index + 1) * len(prompts),
    )
    advantages = advantages[process_slice]

    # Log the metrics
    reward_per_func = rewards_per_func.mean(0)
    for i, reward_func in enumerate(self.reward_funcs):
        if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
            reward_func_name = reward_func.config._name_or_path.split("/")[-1]
        else:
            reward_func_name = reward_func.__name__
        self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

    self._metrics["reward"].append(rewards.mean().item())
    self._metrics["reward_std"].append(std_grouped_rewards.mean().item())

    if (
        self.log_completions
        and self.state.global_step % self.args.logging_steps == 0
        and "wandb" in self.args.report_to
    ):
        import pandas as pd

        # For logging
        table = {
            "step": [str(self.state.global_step)] * len(rewards),
            "prompt": gather_object(prompts_text),
            "completion": gather_object(completions_text),
            "reward": rewards.tolist(),
        }
        df = pd.DataFrame(table)

        if wandb.run is not None and self.accelerator.is_main_process:
            wandb.log({"completions": wandb.Table(dataframe=df)})

    return {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "ref_per_token_logps": ref_per_token_logps,
        "advantages": advantages,
    }
|
| 1105 |
+
|
| 1106 |
+
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
    """Compute the GRPO loss for one prepared batch and record its metrics.

    ``inputs`` is the dict produced by ``_prepare_inputs``. When
    ``_get_per_token_logps`` yields a value the loss comes from
    ``grpo_compute_loss_slow``; when it yields ``None`` the loss comes from
    ``grpo_accumulated_loss``. The completion length and mean KL reported by
    either routine are appended to ``self._metrics``.

    Raises:
        ValueError: if ``return_outputs`` is truthy.
    """
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")

    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]

    full_input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    _batch_size, _seq_len = full_input_ids.shape  # unused; the unpacking fails unless the ids are 2-D
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    # Only the completion tokens need logits.
    logits_to_keep = completion_ids.size(1)

    per_token_logps = self._get_per_token_logps(model, full_input_ids, attention_mask, logits_to_keep)

    ref_per_token_logps = inputs["ref_per_token_logps"]
    advantages = inputs["advantages"]

    if per_token_logps is None:
        # Nothing was precomputed: the accumulated loss receives the full ids.
        loss, completion_length, mean_kl = grpo_accumulated_loss(
            self, full_input_ids, logits_to_keep, completion_mask, advantages,
            n_chunks = self.args.unsloth_num_chunks,
        )
    else:
        completion_only_ids = full_input_ids[:, -logits_to_keep:]
        loss, completion_length, mean_kl = grpo_compute_loss_slow(
            ref_per_token_logps, per_token_logps, completion_only_ids, completion_mask, self.beta, advantages,
        )

    # `_metrics` is either flat or split into "train" / "eval" sub-dicts.
    if "train" in self._metrics:
        bucket = self._metrics["eval" if self.control.should_evaluate else "train"]
    else:
        bucket = self._metrics
    bucket["completion_length"].append(completion_length.item())
    bucket["kl"].append(mean_kl.item())
    return loss
|
| 1157 |
+
|
| 1158 |
+
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
    """Evaluate one batch and return ``(loss, None, None)``.

    The batch goes through ``_prepare_inputs``; the loss is then computed with
    gradients disabled inside ``compute_loss_context_manager()`` and reduced
    to its detached mean. ``prediction_loss_only`` and ``ignore_keys`` are
    accepted but not read.
    """
    prepared = self._prepare_inputs(inputs)
    with torch.no_grad(), self.compute_loss_context_manager():
        mean_loss = self.compute_loss(model, prepared).mean().detach()
    return mean_loss, None, None
|
| 1165 |
+
|
| 1166 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """Merge the averaged trainer metrics into ``logs`` and forward them.

    ``self._metrics`` is read in either of the two layouts that
    ``compute_loss`` writes to: a flat ``{name: [values]}`` mapping, or a
    mapping split into ``"train"`` / ``"eval"`` sub-mappings. Metric lists
    that are empty are skipped rather than averaged. After forwarding, the
    bucket that was reported is cleared.

    Args:
        logs: Values to report. When its first key starts with ``"eval_"``
            the call is treated as an evaluation log and every metric name
            gets the same prefix. An empty mapping is treated as a training
            log.
        start_time: Passed on to the parent ``log`` when the installed
            transformers version is at least 4.47.0.dev0.
    """
    # An empty `logs` must not raise StopIteration.
    is_eval = next(iter(logs), "").startswith("eval_")

    if "train" in self._metrics:
        # Split layout: report only the bucket that matches this call.
        bucket = self._metrics.get("eval" if is_eval else "train")
        if bucket is None:
            bucket = {}
    else:
        bucket = self._metrics

    # Average the metrics; empty lists are skipped to avoid dividing by zero.
    metrics = {key: sum(values) / len(values) for key, values in bucket.items() if values}

    # Evaluation keys in `logs` start with "eval_", so the metric names must match.
    if is_eval:
        metrics = {f"eval_{key}": val for key, val in metrics.items()}

    logs = {**logs, **metrics}
    if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
        super().log(logs, start_time)
    else:  # transformers<=4.46
        super().log(logs)
    bucket.clear()
|
| 1180 |
+
|
| 1181 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on processes for which
    `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The passed object is never modified.
    """
    if not self.is_world_process_zero():
        return

    # A local directory is not a meaningful base-model identifier.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalise to a fresh list so the caller's list is left untouched.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version") and "unsloth" not in tags:
        tags.append("unsloth")

    citation = textwrap.dedent(
        """\
        @article{zhihong2024deepseekmath,
            title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
            author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
            year = 2024,
            eprint = {arXiv:2402.03300},
        }
        """
    )

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="GRPO",
        trainer_citation=citation,
        paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
        paper_id="2402.03300",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1239 |
+
class UnslothGRPOTrainer(_UnslothGRPOTrainer):
    """

    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).

    Example:

    ```python
    from datasets import load_dataset
    from trl import GRPOTrainer

    dataset = load_dataset("trl-lib/tldr", split="train")

    def reward_func(completions, **kwargs):
        # Dummy reward function that rewards completions with more unique letters.
        return [float(len(set(completion))) for completion in completions]

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=reward_func,
        train_dataset=dataset,
    )

    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
              in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
            functions with the prompts and completions and sum the rewards. Can be either:

            - A single reward function, such as:
                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
                path to a *directory* containing model weights saved using
                [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
                using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
                keyword arguments in `args.model_init_kwargs`.
                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
                - A custom reward function: The function is provided with the prompts and the generated completions,
                  plus any additional columns in the dataset. It should return a list of rewards. For more details, see
                  [Using a custom reward function](#using-a-custom-reward-function).
            - A list of reward functions, where each item can independently be any of the above types. Mixing different
            types within the list (e.g., a string model ID and a custom reward function) is allowed.
        args ([`GRPOConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
            ignored. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. The padding side must be set to "left". If `None`, the
            processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
        reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

            - A single processing class: Used when `reward_funcs` contains only one reward function.
            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
            `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
            For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
            the corresponding entries in `reward_processing_classes` are ignored.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.

    """
    def __init__(
        self,
        model,
        reward_funcs,
        args = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        reward_processing_classes = None,
        callbacks = None,
        peft_config = None,
        **kwargs
    ):
        """Reconcile precision and evaluation settings on ``args``, then build the trainer.

        Before delegating to the parent constructor this method:

        * picks fp16 / bf16 (or neither, when ``UNSLOTH_FORCE_FLOAT32`` is
          ``'1'``) from the model dtype and mirrors the choice into the
          ``ACCELERATE_MIXED_PRECISION`` environment variable;
        * fills in evaluation batch size, accumulation steps and the
          ``*_full_eval`` flags;
        * copies ``model.max_seq_length`` into ``args`` when ``args`` leaves
          it unset;
        * switches the model to training mode and sets the processing class
          padding side to ``'right'``;
        * registers the ``rewards/<name>`` metrics with ``PatchRLStatistics``.

        Raises:
            TypeError: if ``args`` requests bf16 for a float16 model, or fp16
                for a non-float16 model, while float32 is not being forced.
        """
        if args is None: args = UnslothGRPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)

        force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
        if force_float32:
            print('Unsloth: Switching to float32 training since model cannot work with float16')
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        # Model dtype: prefer the config entry, fall back to the embedding weights.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if not force_float32 and (float16 and use_bf16):
            raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16):
            raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Nothing requested explicitly: follow the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # NOTE(review): this inspects `args.eval_dataset`, not the `eval_dataset`
        # parameter of this constructor; confirm which one is intended.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of 8 is lowered to the (smaller) training batch size.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # Keep the full-eval precision flags consistent with the training precision.
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Inherit the model's maximum sequence length when `args` has the field but leaves it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if args.max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'

        # Reward callables without a `__name__` (for example model objects or
        # model-id strings) contribute no named metric.
        _reward_funcs = reward_funcs if isinstance(reward_funcs, list) else [reward_funcs]
        other_metrics = []
        for reward_func in _reward_funcs:
            reward_func_name = getattr(reward_func, '__name__', None)
            if reward_func_name is not None:
                other_metrics.append(f'rewards/{reward_func_name}')

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('grpo_trainer', other_metrics)

        super().__init__(
            model = model,
            reward_funcs = reward_funcs,
            args = args,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            reward_processing_classes = reward_processing_classes,
            callbacks = callbacks,
            peft_config = peft_config,**kwargs)

        # Remove the NEFTune hook if the parent installed one, then store the
        # noise alpha on the input embeddings directly.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
unsloth_compiled_cache/UnslothKTOTrainer.py
ADDED
|
@@ -0,0 +1,1840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, deepspeed, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
torch_compile_options = {
|
| 26 |
+
"epilogue_fusion" : True,
|
| 27 |
+
"max_autotune" : False,
|
| 28 |
+
"shape_padding" : True,
|
| 29 |
+
"trace.enabled" : False,
|
| 30 |
+
"triton.cudagraphs" : False,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 34 |
+
def selective_log_softmax(logits, index):
|
| 35 |
+
logits = logits.to(torch.float32)
|
| 36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
| 37 |
+
# loop to reduce peak mem consumption
|
| 38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
| 39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
| 40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
| 41 |
+
return per_token_logps
|
| 42 |
+
@dataclass
class UnslothKTOConfig(KTOConfig):
    """
    Unsloth variant of the configuration class for the [`KTOTrainer`].

    Every argument is forwarded unchanged to `KTOConfig.__init__`, except for
    the two Unsloth-only fields (`vllm_sampling_params`, `unsloth_num_chunks`),
    which are stored on the instance after the parent initialiser has run.
    The signature defaults below are the ones declared by this class.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Adjustments made before calling the parent initialiser:
        - If `output_dir` is `None` while `save_strategy` is `'steps'` and `save_steps` is `500`,
          `output_dir` becomes `'unsloth_training_checkpoints'` and `save_strategy` becomes `'no'`.
        - If `dataset_num_proc` is `None`, it is set to `multiprocessing.cpu_count()`.

    Parameters:
        learning_rate (`float`, *optional*, defaults to `5e-05`):
            Initial learning rate for [`AdamW`] optimizer. Must lie within `[1e-7, 1]`.
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
            to use the default data collator.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion. This argument is required if you want to use the default data collator
            and your model is an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model.
        loss_type (`str`, *optional*, defaults to `"kto"`):
            Type of loss to use. Possible values are:

                - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
                - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.

        desirable_weight (`float`, *optional*, defaults to `1.0`):
            Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
        undesirable_weight (`float`, *optional*, defaults to `1.0`):
            Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int` or `None`, *optional*, defaults to `None`):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
            This argument is required if you want to use the default data collator.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
            evaluation.
        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify if the model returned by the callable is an encoder-decoder model.
        precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
            Whether to precompute reference model log probabilities for training and evaluation datasets. This is
            useful when training without the reference model to reduce the total GPU memory needed.
        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
            string.
        ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
            from a string.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. `None` is replaced by the CPU count.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient.

    Raises:
        FloatingPointError: If `learning_rate` is below `1e-7`.
        OverflowError: If `learning_rate` is above `1`.
    """

    # Unsloth-only fields; they are not passed to KTOConfig.__init__.
    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        tp_size=0,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        evaluation_strategy=None,
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=False,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        dispatch_batches=None,
        split_batches=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        eval_use_gather_object=False,
        average_tokens_across_devices=False,
        max_length=1024,
        max_prompt_length=512,
        max_completion_length=None,
        beta=0.1,
        loss_type='kto',
        desirable_weight=1.0,
        undesirable_weight=1.0,
        label_pad_token_id=-100,
        padding_value=None,
        truncation_mode='keep_end',
        generate_during_eval=False,
        is_encoder_decoder=None,
        disable_dropout=True,
        precompute_ref_log_probs=False,
        model_init_kwargs=None,
        ref_model_init_kwargs=None,
        dataset_num_proc=None,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before anything else happens.
        if learning_rate < 1e-7:
            raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1:
            raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        # With no output directory and untouched save settings, pick a default
        # directory name and turn checkpoint saving off.
        save_settings_untouched = save_strategy == 'steps' and save_steps == 500
        if output_dir is None and save_settings_untouched:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        # Fall back to one dataset worker per reported CPU.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            tp_size=tp_size,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            evaluation_strategy=evaluation_strategy,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            dispatch_batches=dispatch_batches,
            split_batches=split_batches,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            max_completion_length=max_completion_length,
            beta=beta,
            loss_type=loss_type,
            desirable_weight=desirable_weight,
            undesirable_weight=undesirable_weight,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            truncation_mode=truncation_mode,
            generate_during_eval=generate_during_eval,
            is_encoder_decoder=is_encoder_decoder,
            disable_dropout=disable_dropout,
            precompute_ref_log_probs=precompute_ref_log_probs,
            model_init_kwargs=model_init_kwargs,
            ref_model_init_kwargs=ref_model_init_kwargs,
            dataset_num_proc=dataset_num_proc,
            **kwargs,
        )

        # Stored only after the parent initialiser, which does not receive them.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 426 |
+
|
| 427 |
+
class _UnslothKTOTrainer(Trainer):
|
| 428 |
+
r""""""
|
| 429 |
+
|
| 430 |
+
_tag_names = ["trl", "kto"]
|
| 431 |
+
|
| 432 |
+
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module, str] = None,
    ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: KTOConfig = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    data_collator: Optional[DataCollator] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
    model_adapter_name: Optional[str] = None,
    ref_adapter_name: Optional[str] = None,
):
    """Prepare models and datasets for KTO training, then initialise the base `Trainer`.

    Args:
        model: Policy model, or a string handed to `AutoModelForCausalLM.from_pretrained`.
        ref_model: Reference model (object or string to load). When falsy, the reference is `None`
            if the policy is a PEFT model or `args.precompute_ref_log_probs` is set; otherwise it
            is built with `create_reference_model(model)`.
        args: `KTOConfig` holding the training configuration. Required.
        train_dataset: Training dataset; it is mapped and tokenized here.
        eval_dataset: Optional evaluation dataset, preprocessed the same way as `train_dataset`.
        processing_class: Tokenizer/processor used for chat templating, tokenization and padding.
            Required.
        data_collator: Custom collator. When `None`, a `DPODataCollatorWithPadding` is created and
            `args.remove_unused_columns` is forced to `False`.
        model_init: Forwarded unchanged to `Trainer.__init__`.
        callbacks: Forwarded unchanged to `Trainer.__init__`.
        optimizers: Forwarded unchanged to `Trainer.__init__`.
        preprocess_logits_for_metrics: Forwarded unchanged to `Trainer.__init__`.
        peft_config: When given (and PEFT is installed), the model is merged/unloaded if it is
            already a `PeftModel` and prepared for k-bit or gradient-checkpointed training.
        compute_metrics: Forwarded unchanged to `Trainer.__init__`.
        model_adapter_name: Stored on the trainer as `self.model_adapter_name`.
        ref_adapter_name: Stored on the trainer as `self.ref_adapter_name`.

    Raises:
        ValueError: On invalid argument combinations (plain `TrainingArguments`, `model` and
            `ref_model` being the same object, missing `args`, init kwargs passed for an already
            instantiated model, invalid `torch_dtype`, `peft_config` without PEFT installed,
            `generate_during_eval` without wandb/comet, missing `processing_class`, a per-device
            train batch size <= 1 when the KL term is needed, precomputed reference log-probs under
            DeepSpeed ZeRO-3, or no usable reference model).
        AttributeError: If the base `Trainer` did not create an `accelerator`.
    """

    def _enable_input_require_grads(target):
        # Shared by the PEFT and non-PEFT gradient-checkpointing paths below.
        # For backward compatibility with older versions of transformers
        if hasattr(target, "enable_input_require_grads"):
            target.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            target.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if type(args) is TrainingArguments:
        raise ValueError("Please use `KTOConfig` instead TrainingArguments.")

    if not isinstance(model, str) and ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, you must pass a copy of it, or `None` if you use peft."
        )

    # Everything below dereferences `args`; fail with a clear message instead of an AttributeError.
    if args is None:
        raise ValueError("`args` must be a `KTOConfig` instance, but got `None`.")

    if args.model_init_kwargs is None:
        model_init_kwargs = {}
    elif not isinstance(model, str):
        raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
    else:
        model_init_kwargs = args.model_init_kwargs
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            model_init_kwargs["torch_dtype"] = torch_dtype

    if args.ref_model_init_kwargs is None:
        ref_model_init_kwargs = {}
    elif not isinstance(ref_model, str):
        raise ValueError(
            "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
        )
    else:
        ref_model_init_kwargs = args.ref_model_init_kwargs
        torch_dtype = ref_model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            ref_model_init_kwargs["torch_dtype"] = torch_dtype

    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

    if isinstance(ref_model, str):
        ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

    # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
    # has been called in order to properly call autocast if needed.
    self._peft_has_been_casted_to_bf16 = False

    if not is_peft_available() and peft_config is not None:
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            _support_gc_kwargs = hasattr(
                args, "gradient_checkpointing_kwargs"
            ) and "gradient_checkpointing_kwargs" in list(
                inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if _support_gc_kwargs:
                prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
        elif getattr(args, "gradient_checkpointing", False):
            _enable_input_require_grads(model)

        # NOTE(review): `peft_config` is not applied to the model anywhere in this method;
        # presumably the caller passes an already PEFT-wrapped model — confirm.
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
            # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            self._peft_has_been_casted_to_bf16 = True

    # For models that use gradient_checkpointing, we need to attach a hook that enables input
    # to explicitly have `requires_grad=True`, otherwise training will either silently
    # fail or completely fail.
    elif getattr(args, "gradient_checkpointing", False):
        _enable_input_require_grads(model)

    if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
            " Please install `wandb` or `comet-ml` to resolve."
        )

    if model is not None:
        self.is_encoder_decoder = model.config.is_encoder_decoder
    elif args.is_encoder_decoder is None:
        raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
    else:
        self.is_encoder_decoder = args.is_encoder_decoder

    self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
    self.model_adapter_name = model_adapter_name
    self.ref_adapter_name = ref_adapter_name

    if ref_model:
        self.ref_model = ref_model
    elif self.is_peft_model or args.precompute_ref_log_probs:
        # The `model` with adapters turned off will be used as the reference model
        self.ref_model = None
    else:
        self.ref_model = create_reference_model(model)

    if processing_class is None:
        raise ValueError(
            "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
        )
    if args.max_length is None:
        warnings.warn(
            "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
            " it will be set to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_length = 512
    if args.max_length is not None:
        max_length = args.max_length

    if args.max_prompt_length is None:
        warnings.warn(
            "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
            " it will be set to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_prompt_length = 128
    if args.max_prompt_length is not None:
        max_prompt_length = args.max_prompt_length

    # `max_completion_length` only applies to encoder-decoder models; it stays `None` otherwise.
    max_completion_length = None
    if args.max_completion_length is None and self.is_encoder_decoder:
        warnings.warn(
            "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
            " it will be set to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_completion_length = 128
    if args.max_completion_length is not None and self.is_encoder_decoder:
        max_completion_length = args.max_completion_length

    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(
            pad_token_id=processing_class.pad_token_id,
            label_pad_token_id=args.label_pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_dpo_data_collator = True
    else:
        self.use_dpo_data_collator = False

    # Disable dropout in the model and reference model
    if args.disable_dropout:
        disable_dropout_in_model(model)
        if self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

    self.loss_type = args.loss_type
    self.max_length = max_length
    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
    self.max_prompt_length = max_prompt_length
    self.truncation_mode = args.truncation_mode
    self.max_completion_length = max_completion_length
    self.processing_class = processing_class
    self.precompute_ref_log_probs = args.precompute_ref_log_probs

    # Not all losses require a KL calculation
    self.calculate_KL = True
    if self.loss_type in ["apo_zero_unpaired"]:
        self.calculate_KL = False

    # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
    # keep track of first called to avoid computation of future calls
    self._precomputed_train_ref_log_probs = False
    self._precomputed_eval_ref_log_probs = False

    # metric
    self._stored_metrics = defaultdict(lambda: defaultdict(list))

    # KTO parameter
    self.beta = args.beta
    self.desirable_weight = args.desirable_weight
    self.undesirable_weight = args.undesirable_weight
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        warnings.warn(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
            UserWarning,
        )

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
    # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
    # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
    # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
    # issued.
    model.warnings_issued["estimate_tokens"] = True

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        # Extract the prompt if needed
        train_dataset = train_dataset.map(
            maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
        )
        # Unpair the dataset if needed
        train_dataset = maybe_unpair_preference_dataset(
            train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
        )
        # Apply the chat template if needed
        train_dataset = train_dataset.map(
            maybe_apply_chat_template,
            fn_kwargs={"tokenizer": processing_class},
            num_proc=args.dataset_num_proc,
            desc="Applying chat template to train dataset",
        )
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(
                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
            )
            eval_dataset = maybe_unpair_preference_dataset(
                eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
            )
            eval_dataset = eval_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
                desc="Applying chat template to eval dataset",
            )

        # Tokenize and prepare the training datasets
        train_dataset = train_dataset.map(
            _tokenize,
            batched=True,
            fn_kwargs={"tokenizer": self.processing_class},
            num_proc=args.dataset_num_proc,
            desc="Tokenizing train dataset",
        )

        fn_kwargs = {
            "prefix": "",
            "is_encoder_decoder": self.is_encoder_decoder,
            "tokenizer": self.processing_class,
            "max_length": self.max_length,
            "truncation_mode": self.truncation_mode,
            "label_pad_token_id": self.label_pad_token_id,
            "max_prompt_length": self.max_prompt_length,
            "max_completion_length": self.max_completion_length,
        }

        train_dataset = train_dataset.map(
            _process_tokens,
            fn_kwargs=fn_kwargs,
            num_proc=args.dataset_num_proc,
            desc="Processing tokenized train dataset",
        )

        # Tokenize and prepare the eval datasets
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(
                _tokenize,
                fn_kwargs={"tokenizer": self.processing_class},
                batched=True,
                num_proc=args.dataset_num_proc,
                desc="Tokenizing eval dataset",
            )

            eval_dataset = eval_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                desc="Processing tokenized eval dataset",
            )

        # Get KL datasets if needed
        if self.calculate_KL:
            if args.per_device_train_batch_size <= 1:
                raise ValueError(
                    "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
                )

            # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
            # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
            train_kl_dataset = train_dataset.map(
                _get_kl_dataset,
                batched=True,
                batch_size=args.per_device_train_batch_size,
                num_proc=args.dataset_num_proc,
                desc="Extracting KL train dataset",
            )

            # From here on `fn_kwargs` produces "KL_"-prefixed columns (reused for the eval KL set too).
            fn_kwargs["prefix"] = "KL_"
            train_kl_dataset = train_kl_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
                desc="Processing tokenized train KL dataset",
            )

            # merge the datasets
            train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)

            if eval_dataset is not None:
                # Get KL dataset
                eval_kl_dataset = eval_dataset.map(
                    _get_kl_dataset,
                    batched=True,
                    batch_size=args.per_device_train_batch_size,
                    num_proc=args.dataset_num_proc,
                    desc="Extracting eval KL dataset",
                )

                eval_kl_dataset = eval_kl_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
                    desc="Processing tokenized eval KL dataset",
                )

                # merge the datasets
                eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)

        # calculate dataset desirability balance
        num_desirable = max(sum(train_dataset["label"]), 1)
        num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1)  # "label" is binary

        if num_desirable != num_undesirable:
            # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
            des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
            des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
            und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
            und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)

            des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
            und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound

            if not (des_weight_in_range or und_weight_in_range):
                warnings.warn(
                    "You have different amounts of desirable/positive and undesirable/negative examples but the "
                    "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
                    f"on your data, we recommend EITHER "
                    f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
                    f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
                    "See the documentation on how to optimally set these weights.",
                    UserWarning,
                )

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )

    # Deepspeed Zero-3 does not support precompute_ref_log_probs
    if self.is_deepspeed_enabled:
        if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
            raise ValueError(
                "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
            )

    if self.ref_model is None:
        if not (self.is_peft_model or self.precompute_ref_log_probs):
            raise ValueError(
                "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
            )
    else:
        # The reference model is never trained: wrap it for inference only.
        if self.is_deepspeed_enabled:
            self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 891 |
+
|
| 892 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
    """Wrap `model` in a DeepSpeed engine and switch it to eval mode.

    A deep copy of the accelerator's DeepSpeed plugin config is used, so the plugin's own config
    is not mutated. Returns the first element of `deepspeed.initialize(...)`.
    """
    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
    plugin = self.accelerator.state.deepspeed_plugin
    ds_config = deepcopy(plugin.deepspeed_config)

    if model is not None and hasattr(model, "config"):
        # Prefer the largest entry of `hidden_sizes` when present and non-empty.
        if getattr(model.config, "hidden_sizes", None):
            width = max(model.config.hidden_sizes)
        else:
            width = getattr(model.config, "hidden_size", None)

        if width is not None and ds_config["zero_optimization"]["stage"] == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            bucket_overrides = {
                "zero_optimization.reduce_bucket_size": width * width,
                "zero_optimization.stage3_param_persistence_threshold": 10 * width,
                "zero_optimization.stage3_prefetch_bucket_size": 0.9 * width * width,
            }
            ds_config.update(bucket_overrides)

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
    if ds_config["zero_optimization"]["stage"] != 3:
        ds_config["zero_optimization"]["stage"] = 0

    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()
    return engine
|
| 922 |
+
|
| 923 |
+
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation).

    While the context is active:
      * if `self.ref_adapter_name` is set, that adapter is activated on `self.model`;
      * otherwise, if `self.is_peft_model`, the unwrapped model's adapters are disabled via
        `disable_adapter()`;
      * otherwise nothing is changed.

    On exit the policy adapter (`self.model_adapter_name`, or `"default"` when unset) is
    re-activated if a reference adapter had been selected. The restore runs in a `finally`
    clause so that an exception raised inside the `with` body cannot leave the model stuck on
    the reference adapter.
    """
    with (
        self.accelerator.unwrap_model(self.model).disable_adapter()
        if self.is_peft_model and not self.ref_adapter_name
        else nullcontext()
    ):
        if self.ref_adapter_name:
            self.model.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Always switch back to the policy adapter, even when the body raised.
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")
|
| 936 |
+
|
| 937 |
+
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Overrides the base `Trainer` method: on the first call with `precompute_ref_log_probs`
    enabled, reference log-probs are computed over the whole training set (unshuffled) and
    stored as the `reference_logps` column (plus `reference_KL_logps` when the KL term is
    used) before the regular dataloader is built.
    """
    if not self.precompute_ref_log_probs or self._precomputed_train_ref_log_probs:
        return super().get_train_dataloader()

    loader_kwargs = dict(
        batch_size=self.args.per_device_train_batch_size,
        collate_fn=self.data_collator,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        shuffle=False,  # keep dataset order so the new columns line up with the rows
    )
    ref_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **loader_kwargs))

    completion_chunks = []
    kl_chunks = []
    for batch in tqdm(iterable=ref_loader, desc="Train dataset reference log probs"):
        completion_logp, kl_logp = self.compute_reference_log_probs(batch)

        completion_logp = self.accelerator.gather_for_metrics(completion_logp)
        completion_chunks.append(completion_logp.cpu())

        if self.calculate_KL:
            kl_logp = self.accelerator.gather_for_metrics(kl_logp)
            kl_chunks.append(kl_logp.cpu())

    completion_column = torch.cat(completion_chunks).float().numpy()
    self.train_dataset = self.train_dataset.add_column(name="reference_logps", column=completion_column)

    if self.calculate_KL:
        kl_column = torch.cat(kl_chunks).float().numpy()
        self.train_dataset = self.train_dataset.add_column(name="reference_KL_logps", column=kl_column)

    # Remember that the columns exist so later calls skip the precompute pass.
    self._precomputed_train_ref_log_probs = True

    return super().get_train_dataloader()
|
| 980 |
+
|
| 981 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Returns the evaluation [`~torch.utils.data.DataLoader`].

    Overrides the base `Trainer` method: on the first call with `precompute_ref_log_probs`
    enabled, reference log-probs are computed over the evaluation set (unshuffled) and added as
    the `reference_logps` column (plus `reference_KL_logps` when the KL term is used).

    Args:
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
            by the `model.forward()` method are automatically removed. It must implement `__len__`.

    Raises:
        ValueError: If neither `eval_dataset` nor `self.eval_dataset` is available.
    """
    if eval_dataset is None:
        if self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = self.eval_dataset

    if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
        loader_kwargs = dict(
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            shuffle=False,  # keep dataset order so the new columns line up with the rows
        )
        ref_loader = self.accelerator.prepare(DataLoader(eval_dataset, **loader_kwargs))

        completion_chunks = []
        kl_chunks = []
        for batch in tqdm(iterable=ref_loader, desc="Eval dataset reference log probs"):
            completion_logp, kl_logp = self.compute_reference_log_probs(batch)

            completion_logp = self.accelerator.gather_for_metrics(completion_logp)
            completion_chunks.append(completion_logp.cpu())

            if self.calculate_KL:
                kl_logp = self.accelerator.gather_for_metrics(kl_logp)
                kl_chunks.append(kl_logp.cpu())

        completion_column = torch.cat(completion_chunks).float().numpy()
        eval_dataset = eval_dataset.add_column(name="reference_logps", column=completion_column)
        if self.calculate_KL:
            kl_column = torch.cat(kl_chunks).float().numpy()
            eval_dataset = eval_dataset.add_column(name="reference_KL_logps", column=kl_column)

        # Keep the augmented dataset on the trainer (only when one was already set) so the
        # precomputed columns are reused by subsequent evaluation runs.
        if self.eval_dataset is not None:
            self.eval_dataset = eval_dataset
        self._precomputed_eval_ref_log_probs = True

    return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
| 1035 |
+
|
| 1036 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
    """Score one padded KTO batch with the reference model, without gradients.

    When `self.ref_model` is `None`, the policy model (`self.model`) is run inside
    `self.null_ref_context()` instead of a separate reference model.

    Args:
        padded_batch: Collated batch holding the `completion_*` / `prompt_*` entries and,
            when `self.calculate_KL` is set, the matching `KL_*` entries.

    Returns:
        A `(completion_logps, KL_logps)` pair as produced by `get_batch_logps` with
        `average_log_prob=False`. `KL_logps` is `None` when `self.calculate_KL` is false.
        Note that the `dict` return annotation does not describe this tuple.
    """
    policy_is_reference = self.ref_model is None
    scoring_model = self.model if policy_is_reference else self.ref_model

    def _logits_for(prefix):
        # `prefix` is "" for the completion pass and "KL_" for the KL pass.
        if self.is_encoder_decoder:
            return scoring_model(
                padded_batch[prefix + "prompt_input_ids"],
                attention_mask=padded_batch[prefix + "prompt_attention_mask"],
                decoder_input_ids=padded_batch.get(prefix + "completion_decoder_input_ids"),
                labels=padded_batch[prefix + "completion_labels"],
            ).logits
        return scoring_model(
            padded_batch[prefix + "completion_input_ids"],
            attention_mask=padded_batch[prefix + "completion_attention_mask"],
        ).logits

    def _run_passes():
        # The completion pass always runs first; the KL pass only when requested.
        first = _logits_for("")
        second = _logits_for("KL_") if self.calculate_KL else None
        return first, second

    with torch.no_grad():
        if policy_is_reference:
            with self.null_ref_context():
                completion_logits, KL_logits = _run_passes()
        else:
            completion_logits, KL_logits = _run_passes()

    logp_options = {
        "average_log_prob": False,
        "is_encoder_decoder": self.is_encoder_decoder,
        "label_pad_token_id": self.label_pad_token_id,
    }
    completion_logps = self.get_batch_logps(
        completion_logits, padded_batch["completion_labels"], **logp_options
    )

    KL_logps = None
    if self.calculate_KL:
        KL_logps = self.get_batch_logps(KL_logits, padded_batch["KL_completion_labels"], **logp_options)

    return completion_logps, KL_logps
|
| 1114 |
+
|
| 1115 |
+
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Reduce per-token label log probabilities to one value per sequence.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels to score. Positions equal to `label_pad_token_id` are ignored.
            Shape: (batch_size, sequence_length)
        average_log_prob: If True, divide the summed log probability by the number of
            non-ignored tokens; otherwise return the sum.
        label_pad_token_id: Label value marking positions to ignore.
        is_encoder_decoder: If False, labels are shifted left by one position and the last
            logit step is dropped before scoring; if True, both are used as given.

    Returns:
        A tensor of shape (batch_size,) with the average/sum log probability of the labels.

    Raises:
        ValueError: If the leading dimensions of `logits` do not equal `labels.shape`.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    # Work on a copy so the caller's label tensor is never modified in place.
    if is_encoder_decoder:
        targets = labels.clone()
    else:
        targets = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

    keep = targets != label_pad_token_id

    # Ignored positions get a dummy token id; their contribution is zeroed by `keep` below.
    targets[targets == label_pad_token_id] = 0

    token_logps = selective_log_softmax(logits, targets)
    masked_total = (token_logps * keep).sum(-1)

    if average_log_prob:
        return masked_total / keep.sum(-1)
    return masked_total
|
| 1154 |
+
|
| 1155 |
+
def forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run `model` on a batch and split the resulting log probabilities by label.

    Args:
        model: Model to evaluate; called on the completion inputs and, when
            `self.calculate_KL` is set, on the KL inputs under `torch.no_grad()`.
        batch: Collated batch. `batch["label"]` entries are compared with `is True` /
            `is False`, so only those exact objects select an example.

    Returns:
        `(chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)`, with
        `outputs.aux_loss` appended as a sixth item when `self.aux_loss_enabled` is set.
        `KL_logps` is `None` when `self.calculate_KL` is false. Note that the four-element
        return annotation does not describe this tuple.

    Raises:
        ValueError: If the number of scored sequences differs from `len(batch["label"])`.
    """
    logp_options = {
        "average_log_prob": False,
        "is_encoder_decoder": self.is_encoder_decoder,
        "label_pad_token_id": self.label_pad_token_id,
    }

    KL_logps = None
    if self.calculate_KL:
        if self.is_encoder_decoder:
            KL_model_kwargs = {
                "input_ids": batch["KL_prompt_input_ids"],
                "attention_mask": batch["KL_prompt_attention_mask"],
                "labels": batch["KL_completion_labels"],
                "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
            }
        else:
            KL_model_kwargs = {
                "input_ids": batch["KL_completion_input_ids"],
                "attention_mask": batch["KL_completion_attention_mask"],
            }

        # The KL pass never contributes gradients.
        with torch.no_grad():
            KL_logits = model(**KL_model_kwargs).logits

        KL_logps = self.get_batch_logps(KL_logits, batch["KL_completion_labels"], **logp_options)

    model_kwargs = {}
    if self.is_encoder_decoder:
        model_kwargs = {
            "labels": batch["completion_labels"],
            "decoder_input_ids": batch.get("completion_decoder_input_ids"),
        }
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        batch["completion_input_ids"],
        attention_mask=batch["completion_attention_mask"],
        **model_kwargs,
    )
    completion_logits = outputs.logits
    completion_logps = self.get_batch_logps(completion_logits, batch["completion_labels"], **logp_options)

    if completion_logps.shape[0] != len(batch["label"]):
        raise ValueError(
            "There is a mismatch between the number of examples in this batch and the number of "
            "examples for which an output sequence was predicted."
        )

    chosen_idx = [pos for pos, flag in enumerate(batch["label"]) if flag is True]
    rejected_idx = [pos for pos, flag in enumerate(batch["label"]) if flag is False]

    chosen_logps = completion_logps[chosen_idx, ...]
    rejected_logps = completion_logps[rejected_idx, ...]

    chosen_logits = completion_logits[chosen_idx, ...]
    rejected_logits = completion_logits[rejected_idx, ...]

    result = (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
    if self.aux_loss_enabled:
        return result + (outputs.aux_loss,)
    return result
|
| 1233 |
+
|
| 1234 |
+
def kto_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    policy_KL_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    reference_KL_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the KTO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
        policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
        reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)

    Returns:
        A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
        The losses tensor contains the KTO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        The KL tensor contains the detached KL divergence estimate between the policy and reference models.

    Raises:
        ValueError: If `self.loss_type` is neither "kto" nor "apo_zero_unpaired". Without
            this check the loss tensors would be left unassigned and the failure would
            surface later as an `UnboundLocalError`.
    """
    if self.loss_type not in ("kto", "apo_zero_unpaired"):
        raise ValueError(
            f"Unsupported loss_type {self.loss_type!r}; expected 'kto' or 'apo_zero_unpaired'."
        )

    if self.calculate_KL:
        kl = (policy_KL_logps - reference_KL_logps).mean().detach()
        kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
    else:
        kl = torch.zeros(1).to(policy_chosen_logps.device)

    # Chosen losses
    if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
        chosen_logratios = policy_chosen_logps - reference_chosen_logps

        if self.loss_type == "kto":
            # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
            chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
        else:
            # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
            # Use this loss when you believe the chosen outputs are better than your model's default output
            chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)

        chosen_rewards = self.beta * chosen_logratios.detach()
    else:
        # lists can't be empty -- if they are, then accelerate.gather will hang
        chosen_losses = torch.Tensor([]).to(self.accelerator.device)
        chosen_rewards = torch.Tensor([]).to(self.accelerator.device)

    # Rejected losses
    if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        if self.loss_type == "kto":
            rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
        else:
            rejected_losses = F.sigmoid(self.beta * rejected_logratios)

        rejected_rewards = self.beta * rejected_logratios.detach()
    else:
        # lists can't be empty -- if they are, then accelerate.gather will hang
        rejected_losses = torch.Tensor([]).to(self.accelerator.device)
        rejected_rewards = torch.Tensor([]).to(self.accelerator.device)

    # Chosen entries come first, then rejected entries, each scaled by its class weight.
    losses = torch.cat(
        (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
        0,
    )

    return losses, chosen_rewards, rejected_rewards, kl
|
| 1305 |
+
|
| 1306 |
+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
):
    """Compute the KTO loss and other metrics for the given batch of inputs for train or test.

    Args:
        model: Policy model passed to `self.forward`.
        batch: Collated batch; tensor values are moved to `self.accelerator.device`.

    Returns:
        A `(loss, metrics)` pair. `loss` is the nan-ignoring mean of the per-example
        losses, plus `self.aux_loss_coef * aux_loss` when `self.aux_loss_enabled` is set.
        `metrics` always contains "kl"; the "<name>/chosen_sum", "count/chosen" entries
        (and their "rejected" counterparts) are only present when the gathered count for
        that split is positive. These are sums, not averages.
    """
    metrics = {}
    batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

    forward_output = self.forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        policy_KL_logps,
    ) = forward_output[:5]
    # `forward` appends the auxiliary loss as a sixth item only when aux loss is enabled.
    if self.aux_loss_enabled:
        aux_loss = forward_output[5]

    # if reference_logps in batch use them, otherwise use the reference model
    if "reference_logps" in batch:
        # Same `is True` / `is False` split that `forward` applies to the policy outputs.
        chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
        rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]

        reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
        reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
        if self.calculate_KL:
            reference_KL_logps = batch["reference_KL_logps"]
        else:
            reference_KL_logps = None
    else:
        with torch.no_grad():
            if self.ref_model is None:
                # No separate reference model: run the policy model inside the null-ref context.
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                        reference_KL_logps,
                    ) = self.forward(self.model, batch)[:5]
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                    reference_KL_logps,
                ) = self.forward(self.ref_model, batch)[:5]

    losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        policy_KL_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        reference_KL_logps,
    )
    metrics["kl"] = kl.item()

    # Local example counts, wrapped in tensors so they can be gathered across processes.
    num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
    num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

    all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
    all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

    # The branches below depend on the gathered (not local) counts.
    # NOTE(review): presumably so every process issues the same gather calls -- confirm.
    if all_num_chosen > 0:
        metrics["rewards/chosen_sum"] = (
            self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
        )
        metrics["logps/chosen_sum"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
        )
        metrics["logits/chosen_sum"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
        )
        metrics["count/chosen"] = all_num_chosen

    if all_num_rejected > 0:
        metrics["rewards/rejected_sum"] = (
            self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
        )
        metrics["logps/rejected_sum"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
        )
        metrics["logits/rejected_sum"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
        )
        metrics["count/rejected"] = all_num_rejected

    loss = losses.nanmean()
    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
|
| 1402 |
+
|
| 1403 |
+
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the training loss for one batch and record its metrics.

    Args:
        model: Model forwarded to `get_batch_loss_metrics`.
        inputs: Batch forwarded to `get_batch_loss_metrics`.
        return_outputs: If true, return `(loss, metrics)` instead of just the loss.
        num_items_in_batch: Accepted but not used by this implementation.

    Returns:
        The loss moved to `self.args.device`, optionally paired with the metrics dict.
    """
    # Autocast is only entered when the PEFT model was cast to bf16.
    if self._peft_has_been_casted_to_bf16:
        autocast_scope = amp.autocast("cuda")
    else:
        autocast_scope = nullcontext()

    with autocast_scope:
        loss, metrics = self.get_batch_loss_metrics(model, inputs)

    # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
    loss = loss.to(self.args.device)

    # force log the metrics
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
|
| 1424 |
+
|
| 1425 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append each metric value to the per-split history in `self._stored_metrics`.

    Args:
        metrics: Mapping of metric name to the value to record.
        train_eval: Which history ("train" or "eval") receives the values.
    """
    for name in metrics:
        # Looked up per item so an empty `metrics` never touches the history.
        history = self._stored_metrics[train_eval]
        history[name].append(metrics[name])
|
| 1428 |
+
|
| 1429 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
| 1430 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
| 1431 |
+
return None
|
| 1432 |
+
return SequentialSampler(self.train_dataset)
|
| 1433 |
+
|
| 1434 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
    """Generate samples from the model and reference model for the given batch of inputs.

    The reference output is taken from `batch["reference_output"]` when present;
    otherwise it is sampled from `self.ref_model`, or from `self.model` inside
    `self.null_ref_context()` when no reference model is set.

    Returns:
        `(policy_output_decoded, reference_output_decoded)`, each the result of
        `self.processing_class.batch_decode` on outputs padded to `self.max_length`.
    """
    pad_id = self.processing_class.pad_token_id
    sampling_kwargs = {
        "input_ids": batch["prompt_input_ids"],
        "attention_mask": batch["prompt_attention_mask"],
        "max_length": self.max_length,
        "do_sample": True,
        "pad_token_id": pad_id,
    }

    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch cuda amp context manager as some hidden states are silently casted to full precision.
    if self._peft_has_been_casted_to_bf16:
        generate_context_manager = amp.autocast("cuda")
    else:
        generate_context_manager = nullcontext()

    with generate_context_manager:
        policy_output = model.generate(**sampling_kwargs)

        # if reference_output in batch use that otherwise use the reference model
        if "reference_output" in batch:
            reference_output = batch["reference_output"]
        elif self.ref_model is None:
            with self.null_ref_context():
                reference_output = self.model.generate(**sampling_kwargs)
        else:
            reference_output = self.ref_model.generate(**sampling_kwargs)

    decoded = []
    for raw_output in (policy_output, reference_output):
        padded = pad_to_length(raw_output, self.max_length, self.processing_class.pad_token_id)
        decoded.append(self.processing_class.batch_decode(padded, skip_special_tokens=True))

    policy_output_decoded, reference_output_decoded = decoded
    return policy_output_decoded, reference_output_decoded
|
| 1479 |
+
|
| 1480 |
+
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Evaluate one batch, record its metrics and return `(loss, logits, labels)`.

    Args:
        model: Model forwarded to `get_batch_loss_metrics`.
        inputs: Batch forwarded to `get_batch_loss_metrics`.
        prediction_loss_only: If true, return `(loss, None, None)`.
        ignore_keys: Names to leave out of the returned logits. Defaults to
            `model.config.keys_to_ignore_at_inference` when the model has a config,
            otherwise to an empty list.

    Returns:
        `(loss.detach(), logits, labels)`. `logits` holds the batch's summed chosen and
        rejected logits that are present in the metrics and not ignored; `labels` is a
        zero tensor of the same length.
    """
    if ignore_keys is None:
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs)

    # force log the metrics
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model.
    # `get_batch_loss_metrics` stores these under "logits/<split>_sum", and only when the
    # batch contains that split, so each key is read conditionally instead of assuming
    # "logits/chosen" / "logits/rejected" exist (which raised KeyError).
    logits_dict = {}
    if "logits/chosen_sum" in metrics:
        logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
    if "logits/rejected_sum" in metrics:
        logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]

    logits = torch.tensor(
        [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
    )
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
|
| 1515 |
+
|
| 1516 |
+
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When `self.generate_during_eval` is set, a random sample of the evaluation dataset is
    drawn, completions are generated for its `label is False` examples with both the
    policy and the reference model, and the resulting table is sent to the loggers
    named in `self.args.report_to`. The base class loop then runs unchanged.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples.
        # The sample size is capped at the dataset size because `random.sample` raises
        # ValueError when asked for more items than the population holds.
        num_samples = len(dataloader.dataset)
        sample_size = min(self.args.eval_batch_size, num_samples)
        random_indices = random.sample(range(num_samples), k=sample_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        target_indices = [i for i, label in enumerate(random_batch["label"]) if label is False]

        # Nothing to generate when the sample holds no `label is False` example.
        if target_indices:
            target_batch = {
                "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
                "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
                # Built as a list so a single selected prompt stays a one-element sequence;
                # `itemgetter` with one index returns the bare string, which `zip` below
                # would then iterate character by character.
                "prompt": [random_batch["prompt"][i] for i in target_indices],
            }
            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

            table = pd.DataFrame(
                columns=["Prompt", "Policy", "Ref Model"],
                data=[
                    # Strip the prompt prefix so only the generated continuation is logged.
                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                    for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
                ],
            )
            if "wandb" in self.args.report_to:
                wandb.log({"game_log": wandb.Table(data=table)})

            if "comet_ml" in self.args.report_to:
                log_table_to_comet_experiment(
                    name="game_log.csv",
                    table=table,
                )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
|
| 1572 |
+
|
| 1573 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    Stored per-batch sums and counts are turned into averages, a reward margin is added
    when both chosen and rejected rewards are available, every remaining stored metric
    is averaged, and the stored history for the split is then cleared.

    Args:
        logs (`dict[str, float]`):
            The values to log.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # logs either has 'loss' or 'eval_loss'
    is_train = "loss" in logs
    train_eval = "train" if is_train else "eval"
    # train metrics should have no prefix, eval should have 'eval_'
    prefix = "" if is_train else "eval_"

    # accumulate average metrics from sums and lengths
    for split in ("chosen", "rejected"):
        count_key = f"count/{split}"
        if count_key not in self._stored_metrics[train_eval]:
            continue

        count_sum = torch.Tensor(self._stored_metrics[train_eval][count_key]).sum().item()
        for metric in ("rewards", "logps", "logits"):
            sum_key = f"{metric}/{split}_sum"
            total = torch.Tensor(self._stored_metrics[train_eval][sum_key]).sum().item()
            logs[f"{prefix}{metric}/{split}"] = total / count_sum
            # delete obsolete metric
            del self._stored_metrics[train_eval][sum_key]
        del self._stored_metrics[train_eval][count_key]

    # calculate reward margin
    chosen_reward_key = f"{prefix}rewards/chosen"
    rejected_reward_key = f"{prefix}rewards/rejected"
    if chosen_reward_key in logs and rejected_reward_key in logs:
        logs[f"{prefix}rewards/margins"] = logs[chosen_reward_key] - logs[rejected_reward_key]

    # Add averaged stored metrics to logs
    for key, values in self._stored_metrics[train_eval].items():
        logs[f"{prefix}{key}"] = torch.Tensor(values).mean().item()
    del self._stored_metrics[train_eval]

    # The base `log` is only given `start_time` from transformers 4.47.0.dev0 onwards.
    if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
        return super().log(logs, start_time)
    # transformers<=4.46
    return super().log(logs)
|
| 1611 |
+
|
| 1612 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done
    when `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The argument itself is never
            modified; extra tags are added to a copy.
    """
    if not self.is_world_process_zero():
        return

    # A local directory path is not a meaningful base-model identifier, so it is dropped.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        # Copy so that appending below cannot mutate the caller's list.
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
        @article{ethayarajh2024kto,
            title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
            author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
            year = 2024,
            eprint = {arXiv:2402.01306},
        }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="KTO",
        trainer_citation=citation,
        paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
        paper_id="2402.01306",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1667 |
+
class UnslothKTOTrainer(_UnslothKTOTrainer):
|
| 1668 |
+
"""
|
| 1669 |
+
|
| 1670 |
+
Initialize KTOTrainer.
|
| 1671 |
+
|
| 1672 |
+
Args:
|
| 1673 |
+
model (`transformers.PreTrainedModel`):
|
| 1674 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1675 |
+
ref_model (`PreTrainedModelWrapper`):
|
| 1676 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
| 1677 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
| 1678 |
+
args (`KTOConfig`):
|
| 1679 |
+
The arguments to use for training.
|
| 1680 |
+
train_dataset (`datasets.Dataset`):
|
| 1681 |
+
The dataset to use for training.
|
| 1682 |
+
eval_dataset (`datasets.Dataset`):
|
| 1683 |
+
The dataset to use for evaluation.
|
| 1684 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1685 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1686 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1687 |
+
reuse the fine-tuned model.
|
| 1688 |
+
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
| 1689 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1690 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1691 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1692 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1693 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 1694 |
+
The callbacks to use for training.
|
| 1695 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1696 |
+
The optimizer and scheduler to use for training.
|
| 1697 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1698 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 1699 |
+
peft_config (`dict`, defaults to `None`):
|
| 1700 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1701 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1702 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1703 |
+
a dictionary string to metric values.
|
| 1704 |
+
model_adapter_name (`str`, defaults to `None`):
|
| 1705 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 1706 |
+
ref_adapter_name (`str`, defaults to `None`):
|
| 1707 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 1708 |
+
|
| 1709 |
+
"""
|
| 1710 |
+
def __init__(
    self,
    model = None,
    ref_model = None,
    args = None,
    train_dataset = None,
    eval_dataset = None,
    processing_class = None,
    data_collator = None,
    model_init = None,
    callbacks = None,
    preprocess_logits_for_metrics = None,
    peft_config = None,
    compute_metrics = None,
    model_adapter_name = None,
    ref_adapter_name = None,
    **kwargs
):
    """
    Normalise the training arguments, then delegate to the parent trainer.

    Before calling ``super().__init__`` this method:

    * falls back to ``UnslothKTOConfig()`` when ``args`` is None;
    * reconciles ``args.fp16`` / ``args.bf16`` (and the ``*_full_eval`` flags)
      with the model dtype and the ``UNSLOTH_FORCE_FLOAT32`` /
      ``UNSLOTH_MIXED_PRECISION`` environment variables;
    * adjusts evaluation batch size / accumulation when evaluation is enabled;
    * sets ``UNSLOTH_RETURN_LOGITS`` when a metrics callback is supplied;
    * copies ``model.max_seq_length`` into ``args`` when ``args`` leaves it unset;
    * forces right padding on the processing class and swaps the data collator
      for one that matches the dataset columns.

    All parameters are forwarded unchanged to the parent constructor, except
    ``args`` (mutated in place) and ``data_collator`` (possibly replaced).

    Raises:
        TypeError: if the requested half-precision mode contradicts the model
            dtype and float32 is not being forced.
    """
    if args is None:
        args = UnslothKTOConfig()

    # ---- training precision ------------------------------------------------
    wants_bf16 = getattr(args, 'bf16', False)
    wants_fp16 = getattr(args, 'fp16', False)
    force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
    if force_float32:
        print('Unsloth: Switching to float32 training since model cannot work with float16')
    mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

    # NOTE(review): `model` defaults to None but is dereferenced here, so a
    # caller relying on `model_init` alone fails with AttributeError.
    model_dtype = getattr(model.config, 'torch_dtype', None)
    if model_dtype is None:
        model_dtype = model.get_input_embeddings().dtype
    from unsloth_zoo.utils import _get_dtype
    model_is_fp16 = _get_dtype(model_dtype) == torch.float16

    if not force_float32:
        if model_is_fp16 and wants_bf16:
            raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not model_is_fp16 and wants_fp16:
            raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

    if force_float32:
        args.fp16 = False
        args.bf16 = False
        os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
    elif not (wants_bf16 or wants_fp16) and mixed_precision_dtype == 'float32':
        # Nothing was requested explicitly: follow the model dtype.
        args.fp16 = model_is_fp16
        args.bf16 = not model_is_fp16
        os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if model_is_fp16 else 'bf16'

    # ---- evaluation defaults -----------------------------------------------
    # NOTE(review): this inspects `args.eval_dataset`, not the `eval_dataset`
    # argument of this method - confirm which one was intended.
    if (
        getattr(args, 'eval_dataset', None) is not None
        and getattr(args, 'eval_strategy', 'no') == 'no'
    ):
        args.eval_strategy = 'steps'
        if getattr(args, 'eval_steps', None) is None:
            args.eval_steps = 0.1

    accum_steps = getattr(args, 'gradient_accumulation_steps', None)
    if accum_steps is not None and accum_steps > 1:
        from transformers import __version__ as transformers_version
        if Version(transformers_version) <= Version('4.45.2'):
            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                  '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

    if getattr(args, 'eval_strategy', 'no') != 'no':
        eval_batch = getattr(args, 'per_device_eval_batch_size', 8)
        # An eval batch size of exactly 8 is shrunk to the train batch size when that is smaller.
        if eval_batch == 8 and args.per_device_train_batch_size < eval_batch:
            args.per_device_eval_batch_size = args.per_device_train_batch_size
        if getattr(args, 'eval_accumulation_steps', None) is None and accum_steps is not None:
            args.eval_accumulation_steps = accum_steps

    # ---- evaluation precision ----------------------------------------------
    fp16_full_eval = getattr(args, 'fp16_full_eval', False)
    bf16_full_eval = getattr(args, 'bf16_full_eval', False)
    if args.fp16 and bf16_full_eval:
        args.bf16_full_eval = False
        args.fp16_full_eval = True
    if args.bf16 and fp16_full_eval:
        args.bf16_full_eval = True
        args.fp16_full_eval = False
    if force_float32:
        args.bf16_full_eval = False
        args.fp16_full_eval = False
    elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
        args.bf16_full_eval = True
        args.fp16_full_eval = False
    elif not bf16_full_eval and not fp16_full_eval:
        args.bf16_full_eval = args.bf16
        args.fp16_full_eval = args.fp16

    # ---- logits for metrics ------------------------------------------------
    if compute_metrics is not None or preprocess_logits_for_metrics is not None:
        os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

    # ---- sequence length ---------------------------------------------------
    # Only applies when `args` exposes `max_seq_length` and leaves it as None.
    if hasattr(args, 'max_seq_length'):
        model_max_seq_length = getattr(model, 'max_seq_length', None)
        if getattr(args, 'max_seq_length', None) is None and model_max_seq_length is not None:
            args.max_seq_length = model.max_seq_length

    if model is not None and hasattr(model, 'for_training'):
        model.for_training()

    # ---- padding side ------------------------------------------------------
    if hasattr(processing_class, 'padding_side'):
        processing_class.padding_side = 'right'
    if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
        processing_class.tokenizer.padding_side = 'right'
    active_tokenizer = processing_class

    # ---- data collator -----------------------------------------------------
    from unsloth_zoo.vision_utils import UnslothVisionDataCollator
    uses_vision_collator = isinstance(data_collator, UnslothVisionDataCollator)
    if uses_vision_collator:
        if hasattr(args, 'remove_unused_columns'):
            args.remove_unused_columns = False
        if hasattr(args, 'dataset_text_field'):
            args.dataset_text_field = ''
        if hasattr(args, 'dataset_kwargs'):
            args.dataset_kwargs = {'skip_prepare_dataset': True}
    else:
        # Pick the collator that matches whether the dataset carries `labels`.
        if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
            data_collator = DataCollatorForLanguageModeling(active_tokenizer, mlm = False)
        elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
            data_collator = DataCollatorForSeq2Seq(active_tokenizer)

        # If the object has no `pad` but wraps a `.tokenizer`, rebuild the
        # collator around the inner tokenizer.
        if not hasattr(active_tokenizer, 'pad') and hasattr(active_tokenizer, 'tokenizer'):
            if isinstance(data_collator, DataCollatorForSeq2Seq):
                data_collator = DataCollatorForSeq2Seq(active_tokenizer.tokenizer)
            else:
                data_collator = DataCollatorForLanguageModeling(active_tokenizer.tokenizer, mlm = False)

    other_metrics = []

    from unsloth_zoo.logging_utils import PatchRLStatistics
    PatchRLStatistics('kto_trainer', other_metrics)

    super().__init__(
        model = model,
        ref_model = ref_model,
        args = args,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        processing_class = processing_class,
        data_collator = data_collator,
        model_init = model_init,
        callbacks = callbacks,
        preprocess_logits_for_metrics = preprocess_logits_for_metrics,
        peft_config = peft_config,
        compute_metrics = compute_metrics,
        model_adapter_name = model_adapter_name,
        ref_adapter_name = ref_adapter_name,
        **kwargs
    )

    # ---- NEFTune -----------------------------------------------------------
    # Remove any hook the parent registered and set the noise alpha directly
    # on the input embeddings instead.
    if hasattr(self, 'neftune_hook_handle'):
        self.neftune_hook_handle.remove()
        if hasattr(self, 'neftune_hook_handle'):
            del self.neftune_hook_handle
    if getattr(args, 'neftune_noise_alpha', None) is not None:
        model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1839 |
+
|
| 1840 |
+
pass
|
unsloth_compiled_cache/UnslothNashMDTrainer.py
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options passed as `options=` to `torch.compile` for the helpers in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of ``logits`` evaluated only at the positions in ``index``.

    ``logits`` is upcast to float32, the entries chosen by ``index`` are gathered
    along the last dimension, and the log-sum-exp over that same dimension is
    subtracted, using ``log_softmax(x_i) = x_i - logsumexp(x)``.

    Args:
        logits: tensor whose last dimension is reduced over.
        index: integer tensor of positions along the last dimension of
            ``logits``; it is given a trailing axis for ``torch.gather``.

    Returns:
        Float32 tensor of the selected log-probabilities, one per entry of ``index``.
    """
    logits_fp32 = logits.to(torch.float32)
    picked = torch.gather(logits_fp32, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # One logsumexp over the whole tensor; the selected term is subtracted from it.
    normaliser = torch.logsumexp(logits_fp32, dim = -1)
    return picked - normaliser
|
| 42 |
+
@dataclass
class UnslothNashMDConfig(NashMDConfig):
    """
    `NashMDConfig` with Unsloth defaults and two extra fields.

    Every keyword accepted by ``__init__`` other than the two extra fields is
    forwarded to ``NashMDConfig.__init__``. Before forwarding, ``__init__``:

    * rejects a ``learning_rate`` below ``1e-7`` or above ``1``;
    * when ``output_dir`` is None and the save settings are still
      ``save_strategy='steps'`` / ``save_steps=500``, uses
      ``'unsloth_training_checkpoints'`` as output dir and disables saving;
    * defaults ``dataset_num_proc`` to the CPU count.

    Extra fields:
        vllm_sampling_params: optional vLLM ``SamplingParams`` object.
        unsloth_num_chunks: chunk count used to reduce memory usage; ``-1``
            (the default) is described by the field help as most efficient.
    """

    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        tp_size=0,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        evaluation_strategy=None,
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=False,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        dispatch_batches=None,
        split_batches=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        eval_use_gather_object=False,
        average_tokens_across_devices=False,
        reward_model_path=None,
        judge=None,
        max_new_tokens=64,
        max_length=512,
        temperature=0.9,
        missing_eos_penalty=None,
        loss_type='sigmoid',
        dataset_num_proc=None,
        disable_dropout=True,
        use_vllm=False,
        ds3_gather_for_generation=True,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before building the config.
        if learning_rate < 1e-7:
            raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1:
            raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        # With no output dir and untouched save settings, pick a default
        # directory and turn checkpoint saving off.
        save_settings_untouched = save_strategy == 'steps' and save_steps == 500
        if output_dir is None and save_settings_untouched:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            tp_size=tp_size,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            evaluation_strategy=evaluation_strategy,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            dispatch_batches=dispatch_batches,
            split_batches=split_batches,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            reward_model_path=reward_model_path,
            judge=judge,
            max_new_tokens=max_new_tokens,
            max_length=max_length,
            temperature=temperature,
            missing_eos_penalty=missing_eos_penalty,
            loss_type=loss_type,
            dataset_num_proc=dataset_num_proc,
            disable_dropout=disable_dropout,
            use_vllm=use_vllm,
            ds3_gather_for_generation=ds3_gather_for_generation,
            **kwargs,
        )

        # The two Unsloth-only settings are stored directly on the instance
        # rather than passed to the parent constructor.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 366 |
+
|
| 367 |
+
class _UnslothNashMDTrainer(OnlineDPOTrainer):
|
| 368 |
+
r""""""
|
| 369 |
+
|
| 370 |
+
_tag_names = ["trl", "nash-md"]
|
| 371 |
+
|
| 372 |
+
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module] = None,
    ref_model: Union[PreTrainedModel, nn.Module] = None,
    reward_model: Union[PreTrainedModel, nn.Module, None] = None,
    judge: Optional[BasePairwiseJudge] = None,
    args: Optional[NashMDConfig] = None,
    data_collator: Optional[Callable] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
    """
    Build the parent online-DPO trainer, then install Nash-MD bookkeeping.

    All arguments are forwarded to the parent constructor; ``processing_class``
    is also passed as ``reward_processing_class``. Afterwards the mixture
    coefficient is cached from ``self.args.mixture_coef`` and ``self.stats``
    is replaced with the Nash-MD set of metric buckets.
    """
    super().__init__(
        model=model,
        ref_model=ref_model,
        reward_model=reward_model,
        judge=judge,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        # for now, NashMDTrainer can't use any reward model
        reward_processing_class=processing_class,
        peft_config=peft_config,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    self._mixture_coef = self.args.mixture_coef

    # Replace the parent's stats dictionary: "non_score_reward",
    # "rlhf_reward" and "scores_margin" are dropped, "mixture_coef" is added.
    stat_names = (
        "loss/kl",
        "objective/entropy",
        "loss/score",
        "rewards/probabilities",
        "rewards/accuracies",
        "rewards/margins",
        "logps/chosen",
        "logps/rejected",
        "val/model_contain_eos_token",
        "val/ref_contain_eos_token",
        "beta",
        "mixture_coef",
    )
    self.stats = {stat_name: [] for stat_name in stat_names}

    # Per-completion reward buckets are only tracked when a reward model is set.
    if self.reward_model is not None:
        for reward_key in ("rewards/chosen", "rewards/rejected"):
            self.stats[reward_key] = []
|
| 431 |
+
|
| 432 |
+
@property
def mixture_coef(self):
    """Return the mixture coefficient that applies to the current epoch.

    ``self._mixture_coef`` is either a single value, which is returned as-is,
    or a list holding one value per epoch. For a list, the entry for the
    current epoch is returned, and the last entry is reused once training
    runs past the end of the list.

    Returns:
        The configured coefficient for the current epoch.
    """
    schedule = self._mixture_coef
    if not isinstance(schedule, list):
        return schedule

    # `TrainerState.epoch` is documented as a float (the fractional part is the
    # progress through the current epoch) and is None before training starts.
    # A list cannot be indexed with either, so reduce it to a whole-epoch index.
    epoch = self.state.epoch
    epoch_index = 0 if epoch is None else int(epoch)
    return schedule[epoch_index] if epoch_index < len(schedule) else schedule[-1]
|
| 439 |
+
|
| 440 |
+
def _generate_completions(self, model, prompts):
    """Generate one completion per prompt from the policy and from the mixture model.

    Args:
        model: The policy being trained.
        prompts: Mapping providing the ``"input_ids"`` and ``"attention_mask"``
            entries that are forwarded to ``generate``.

    Returns:
        A ``(model_output, mixture_output)`` tuple holding what the two
        ``generate`` calls returned.
    """
    generation_inputs = {
        "input_ids": prompts["input_ids"],
        "attention_mask": prompts["attention_mask"],
        "generation_config": self.generation_config,
    }
    # Without a dedicated reference model the policy itself plays that role.
    reference = self.ref_model if self.ref_model is not None else model

    with unwrap_model_for_generation(model, self.accelerator) as policy:
        model_output = policy.generate(**generation_inputs)

        # NOTE(review): the mixture is built while the policy is still unwrapped
        # because the wrapper receives the unwrapped policy object; the nesting is
        # reconstructed from that data dependency -- confirm against the original.
        with torch.no_grad(), unwrap_model_for_generation(reference, self.accelerator) as unwrapped_reference:
            mixture_model = GeometricMixtureWrapper(
                model=policy,
                ref_model=unwrapped_reference,
                generation_config=self.generation_config,
                mixture_coef=self.mixture_coef,
                device=self.accelerator.device,
            )
            mixture_output = mixture_model.generate(**generation_inputs)

    return model_output, mixture_output
|
| 465 |
+
|
| 466 |
+
def _process_completions(self, model_output, mixture_output, prompts):
    """Turn raw generation outputs into prompt+completion batches.

    For each output, the prompt columns are dropped, the remainder is passed
    through ``truncate_right`` together with the EOS and pad token ids, and the
    returned ids/mask are concatenated back onto the prompt ids/mask.

    Args:
        model_output: Token ids generated by the policy (prompt included).
        mixture_output: Token ids generated by the mixture model (prompt included).
        prompts: Mapping with ``"input_ids"``, ``"attention_mask"`` and ``"raw"``.

    Returns:
        A ``(model_data, mixture_data)`` tuple of dicts, each with the keys
        ``"input_ids"``, ``"attention_mask"`` and ``"raw"``.
    """
    context_length = prompts["input_ids"].shape[1]
    tokenizer = self.processing_class

    def assemble(generated):
        # Everything after the prompt columns is the completion.
        completion_ids, completion_mask = truncate_right(
            generated[:, context_length:], tokenizer.eos_token_id, tokenizer.pad_token_id
        )
        return {
            "input_ids": torch.cat((prompts["input_ids"], completion_ids), dim=1),
            "attention_mask": torch.cat((prompts["attention_mask"], completion_mask), dim=1),
            "raw": prompts["raw"],
        }

    model_data = assemble(model_output)
    mixture_data = assemble(mixture_output)
    return model_data, mixture_data
|
| 492 |
+
|
| 493 |
+
def _compute_rewards(self, model_data, mixture_data, context_length):
    """Score the policy and mixture completions with the reward model.

    Args:
        model_data: Batch dict for the policy completions (uses ``"input_ids"``).
        mixture_data: Batch dict for the mixture completions (uses ``"input_ids"``).
        context_length: Value forwarded to ``get_reward`` as its last argument.

    Returns:
        A ``(model_scores, mixture_scores)`` tuple. When
        ``args.missing_eos_penalty`` is set, the penalty has been subtracted
        in place from every row whose ``input_ids`` contain no EOS token.
    """
    pad_token_id = self.processing_class.pad_token_id

    with torch.no_grad():
        # Only the middle element of get_reward's 3-tuple is used here.
        _, model_scores, _ = get_reward(self.reward_model, model_data["input_ids"], pad_token_id, context_length)
        _, mixture_scores, _ = get_reward(self.reward_model, mixture_data["input_ids"], pad_token_id, context_length)

    penalty = self.args.missing_eos_penalty
    if penalty is None:
        return model_scores, mixture_scores

    eos_token_id = self.processing_class.eos_token_id
    model_has_eos = torch.any(model_data["input_ids"] == eos_token_id, dim=-1)
    mixture_has_eos = torch.any(mixture_data["input_ids"] == eos_token_id, dim=-1)
    model_scores[~model_has_eos] -= penalty
    mixture_scores[~mixture_has_eos] -= penalty
    return model_scores, mixture_scores
|
| 510 |
+
|
| 511 |
+
def _compute_judge(self, model_data, mixture_data, context_length):
    """Ask the pairwise judge to compare policy completions with mixture completions.

    Completions are decoded (special tokens skipped) and stripped. When the
    first raw prompt is conversational, prompts and completions are rendered
    to plain text with ``SIMPLE_CHAT_TEMPLATE`` before being judged.

    Args:
        model_data: Batch dict for the policy completions (``"input_ids"``, ``"raw"``).
        mixture_data: Batch dict for the mixture completions (``"input_ids"``).
        context_length: Number of leading prompt columns to drop before decoding.

    Returns:
        A tensor built from the judge's scores, placed on the device of
        ``model_data["input_ids"]``.
    """
    prompts = model_data["raw"]

    def decode_completions(data):
        texts = self.processing_class.batch_decode(
            data["input_ids"][:, context_length:], skip_special_tokens=True
        )
        return [text.strip() for text in texts]

    model_completions = decode_completions(model_data)
    mixture_completions = decode_completions(mixture_data)

    if is_conversational({"prompt": prompts[0]}):
        template = jinja2.Environment().from_string(SIMPLE_CHAT_TEMPLATE)

        def render_assistant_turns(completions):
            # Wrap each completion as a single assistant message, then render it.
            return [
                template.render(messages=[{"role": "assistant", "content": completion}])
                for completion in completions
            ]

        prompts = [template.render(messages=message) for message in prompts]
        model_completions = render_assistant_turns(model_completions)
        mixture_completions = render_assistant_turns(mixture_completions)

    probability = self.judge.judge(
        prompts,
        list(zip(model_completions, mixture_completions)),
        return_scores=True,
    )
    return torch.tensor(probability, device=model_data["input_ids"].device)
|
| 544 |
+
|
| 545 |
+
def _compute_logprobs(self, model, model_data, context_length):
    """Compute per-token log-probabilities of the policy completions.

    The completions are scored twice: once by ``model`` (gradients kept) and
    once by the reference under ``torch.no_grad()``. The reference is
    ``self.ref_model`` when set, otherwise ``model`` inside
    ``model.disable_adapter()``.

    Args:
        model: The policy being trained.
        model_data: Batch dict with ``"input_ids"`` and ``"attention_mask"``.
        context_length: Number of leading prompt columns.

    Returns:
        A ``(model_logprobs, ref_logprobs)`` tuple in which positions whose
        attention mask is 0 have been set to ``0.0``.
    """
    completion_ids = model_data["input_ids"][:, context_length:]

    def completion_logprobs(scorer):
        outputs = scorer(model_data["input_ids"], attention_mask=model_data["attention_mask"])
        # Logits at position t are paired with the token at position t + 1,
        # hence the slice shifted one step to the left of the completion.
        shifted_logits = outputs.logits[:, context_length - 1 : -1]
        return selective_log_softmax(shifted_logits, completion_ids)

    policy_logprobs = completion_logprobs(model)

    with torch.no_grad():
        if self.ref_model is not None:
            reference_logprobs = completion_logprobs(self.ref_model)
        else:
            with model.disable_adapter():
                reference_logprobs = completion_logprobs(model)

    padding_positions = model_data["attention_mask"][:, context_length:] == 0
    policy_logprobs = policy_logprobs.masked_fill(padding_positions, 0.0)
    reference_logprobs = reference_logprobs.masked_fill(padding_positions, 0.0)

    return (policy_logprobs, reference_logprobs)
|
| 569 |
+
|
| 570 |
+
def _compute_losses(
    self,
    model_logprobs_model_data,
    ref_logprobs_model_data,
    probability,
):
    """Combine the preference score and the KL term into the training loss.

    Args:
        model_logprobs_model_data: Per-token log-probs of the completions under the policy.
        ref_logprobs_model_data: Per-token log-probs of the same completions under the reference.
        probability: Per-sequence preference probability.

    Returns:
        ``(loss, score, kl_div_log)``: the batch-mean scalar loss, the
        per-sequence score, and the per-sequence summed log-ratio.
    """
    # REINFORCE-style score; subtracting 0.5 acts as a control variate.
    sequence_logprobs = model_logprobs_model_data.sum(1)
    score = (probability - 0.5) * sequence_logprobs

    # KL divergence via REINFORCE: the log-ratio itself carries no gradient.
    with torch.no_grad():
        log_ratio = model_logprobs_model_data - ref_logprobs_model_data
        kl_div_log = log_ratio.sum(1)
    # NOTE(review): indentation was lost in the extracted source; this line is
    # assumed to sit outside the no_grad block so the KL term contributes a
    # gradient through the policy log-probs -- confirm against the original file.
    kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)

    per_sequence_loss = self.beta * kl_div_loss - score
    return per_sequence_loss.mean(), score, kl_div_log
|
| 589 |
+
|
| 590 |
+
def _log_statistics(
    self,
    model_data,
    mixture_data,
    model_logprobs_model_data,
    ref_logprobs_model_data,
    probability,
    score,
    kl_div,
    context_length,
    model_scores=None,
    mixture_scores=None,
):
    """Append this step's metrics to ``self.stats``.

    Every tensor metric is gathered with ``accelerator.gather_for_metrics``,
    averaged, and stored as a Python float. ``beta`` and ``mixture_coef`` are
    stored as-is. The reward entries are only written when a reward model is
    configured.

    Args:
        model_data: Batch dict for the policy completions (uses ``"input_ids"``).
        mixture_data: Batch dict for the mixture completions (uses ``"input_ids"``).
        model_logprobs_model_data: Per-token policy log-probs of the completions.
        ref_logprobs_model_data: Per-token reference log-probs of the completions.
        probability: Per-sequence preference probability.
        score: Per-sequence score term of the loss.
        kl_div: Per-sequence KL estimate.
        context_length: Number of leading prompt columns.
        model_scores: Reward-model scores of the policy completions, if any.
        mixture_scores: Reward-model scores of the mixture completions, if any.
    """
    stats = self.stats
    gather = self.accelerator.gather_for_metrics

    def record(key, tensor):
        stats[key].append(gather(tensor).mean().item())

    # Loss components.
    record("loss/score", score)
    record("loss/kl", kl_div)

    # Sequence-level log-probabilities.
    policy_sequence_logprobs = model_logprobs_model_data.sum(1)
    reference_sequence_logprobs = ref_logprobs_model_data.sum(1)
    record("logps/chosen", policy_sequence_logprobs)
    record("logps/rejected", reference_sequence_logprobs)

    # Reward-model scores exist only when a reward model is in use.
    if self.reward_model is not None:
        record("rewards/chosen", model_scores)
        record("rewards/rejected", mixture_scores)

    record("rewards/probabilities", probability)

    # Entropy proxy: negated summed log-prob of the sampled completion.
    record("objective/entropy", -model_logprobs_model_data.sum(1))

    # Margin between policy and reference, and how often it is positive.
    margin = policy_sequence_logprobs - reference_sequence_logprobs
    record("rewards/margins", margin)
    record("rewards/accuracies", (margin > 0).float())

    # Fraction of completions that contain an EOS token.
    eos_token_id = self.processing_class.eos_token_id
    model_has_eos = (model_data["input_ids"][:, context_length:] == eos_token_id).any(dim=1)
    mixture_has_eos = (mixture_data["input_ids"][:, context_length:] == eos_token_id).any(dim=1)
    record("val/model_contain_eos_token", model_has_eos.float())
    record("val/ref_contain_eos_token", mixture_has_eos.float())

    # Current hyper-parameter values.
    stats["beta"].append(self.beta)
    stats["mixture_coef"].append(self.mixture_coef)
|
| 648 |
+
|
| 649 |
+
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """Run one Nash-MD optimisation step on a batch of prompts.

    The step tokenizes the prompts, samples completions from the policy and
    from the mixture model, obtains a preference probability (reward model if
    configured, judge otherwise), computes the loss, logs statistics and
    back-propagates.

    Args:
        model: The policy being trained.
        inputs: Column-oriented batch; must contain a ``"prompt"`` column.
        num_items_in_batch: Accepted for interface compatibility; not used here.

    Returns:
        The detached loss divided by ``args.gradient_accumulation_steps``.
    """
    model.train()

    # Turn the column-oriented batch into per-example rows, apply the chat
    # template, tokenize each row and collate the result.
    batch_size = len(next(iter(inputs.values())))
    raw_prompts = inputs["prompt"]
    examples = [{key: column[i] for key, column in inputs.items()} for i in range(batch_size)]
    examples = [maybe_apply_chat_template(example, self.processing_class) for example in examples]
    examples = [
        self.tokenize_row(example, self.model.config.is_encoder_decoder, self.processing_class)
        for example in examples
    ]
    batch = self._prepare_inputs(self.data_collator(examples))

    # Only the prompt_* entries are needed from here on.
    context_length = batch["prompt_input_ids"].shape[1]
    prompts = {
        "input_ids": batch["prompt_input_ids"],
        "attention_mask": batch["prompt_attention_mask"],
        "raw": raw_prompts,
    }
    del examples, batch

    # Sample from the policy and the mixture model, then build full batches.
    model_output, mixture_output = self._generate_completions(model, prompts)
    model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)

    # Preference probability of the policy completion over the mixture completion.
    if self.reward_model is None:
        model_scores, mixture_scores = None, None
        probability = self._compute_judge(model_data, mixture_data, context_length)
    else:
        model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
        probability = F.sigmoid(model_scores - mixture_scores)

    model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
    loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)

    self._log_statistics(
        model_data,
        mixture_data,
        model_logprobs_model_data.detach(),
        ref_logprobs_model_data,
        probability,
        score.detach(),
        kl_div.detach(),
        context_length,
        model_scores,
        mixture_scores,
    )

    empty_cache_steps = self.args.torch_empty_cache_steps
    if empty_cache_steps is not None and self.state.global_step % empty_cache_steps == 0:
        empty_cache()

    # LOMO optimizers need the learning rate passed explicitly to backward().
    backward_kwargs = {}
    if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        backward_kwargs["learning_rate"] = self._get_learning_rate()

    if self.args.n_gpu > 1:
        loss = loss.mean()  # average over devices in multi-GPU parallel training

    if not self.use_apex:
        self.accelerator.backward(loss, **backward_kwargs)
    else:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()

    return loss.detach() / self.args.gradient_accumulation_steps
|
| 728 |
+
|
| 729 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `args.output_dir`. Nothing is done on processes other than the
    world-zero process.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. A list passed by the caller is copied, never modified.
    """
    if not self.is_world_process_zero():
        return

    # A local directory path is not a meaningful "base model" reference.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        # Copy so that appending below does not mutate the caller's list.
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @inproceedings{munos2024nash,
    title = {{Nash Learning from Human Feedback}},
    author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
    year = 2024,
    booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
    publisher = {OpenReview.net},
    url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="Nash-MD",
        trainer_citation=citation,
        paper_title="Nash Learning from Human Feedback",
        paper_id="2312.00886",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 786 |
+
class UnslothNashMDTrainer(_UnslothNashMDTrainer):
    """
    Unsloth wrapper around the Nash-MD trainer (itself built on [`OnlineDPOTrainer`]).

    Before delegating to the parent constructor it reconciles the precision flags on `args` with the model
    dtype, fills in evaluation defaults, forces right padding on the processing class and swaps in a data
    collator that matches the dataset.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward
            computation and loss. If no reference model is provided, the trainer will create a reference model
            with the same architecture as the model to be optimized.
        reward_model (`transformers.PreTrainedModel`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`NashMDConfig`):
            The NashMD config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator
            (`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of
            the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the
            inputs for the model, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary
            string to metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training (forwarded through `**kwargs`).
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
    """

    def __init__(
        self,
        model = None,
        ref_model = None,
        reward_model = None,
        judge = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        peft_config = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        **kwargs
    ):
        if args is None:
            args = UnslothNashMDConfig()

        # ---- Training precision ------------------------------------------------
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        # NOTE(review): `model` defaults to None in the signature but is
        # dereferenced unconditionally here, so it is effectively required.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None:
            dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if not force_float32:
            # The requested mixed-precision mode must agree with the model dtype.
            if float16 and use_bf16:
                raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
            if not float16 and use_fp16:
                raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Nothing requested explicitly: follow the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults ----------------------------------------------
        # NOTE(review): this inspects `args.eval_dataset`, not the `eval_dataset`
        # parameter -- confirm that is intended.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None:
                args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # ---- Full-eval precision ----------------------------------------------
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Metric callbacks need the logits to be returned.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit the model's max_seq_length when the config has the field but no value.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model.max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Tokenizer / data collator ----------------------------------------
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'
        tokenizer_like = processing_class

        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        uses_vision_collator = isinstance(data_collator, UnslothVisionDataCollator)
        if uses_vision_collator:
            if hasattr(args, 'remove_unused_columns'):
                args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'):
                args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'):
                args.dataset_kwargs = {'skip_prepare_dataset': True}
        else:
            # Pick the collator flavour that matches whether the dataset has labels.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(tokenizer_like, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(tokenizer_like)

            # A processor without `pad` gets its inner tokenizer handed to the collator.
            if not hasattr(tokenizer_like, 'pad') and hasattr(tokenizer_like, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(tokenizer_like.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(tokenizer_like.tokenizer, mlm = False)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('nash_md_trainer', other_metrics)

        super().__init__(
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            judge = judge,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            **kwargs
        )

        # Drop any NEFTune forward hook the parent registered and set the noise
        # alpha on the input embeddings directly instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 954 |
+
|
| 955 |
+
pass
|
unsloth_compiled_cache/UnslothORPOTrainer.py
ADDED
|
@@ -0,0 +1,1543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, deepspeed, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options handed to `torch.compile(..., options = ...)` by the compiled helpers
# in this module. The keys are passed through verbatim, so they must stay
# spelled exactly as below.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Return the log-softmax of `logits` evaluated only at the entries picked by `index`.

    The full log-softmax tensor is never materialised: the selected logits are
    gathered first and the log-normaliser is subtracted afterwards, using
    `log_softmax(x)_i = x_i - logsumexp(x)`.

    Args:
        logits: Tensor of scores; the softmax is taken over its last dimension.
            It is upcast to `torch.float32` before any arithmetic.
        index: Integer tensor of positions along the last dimension of `logits`.
            It is unsqueezed on the last axis before `torch.gather`, so it is
            expected to have one dimension fewer than `logits`.

    Returns:
        A `torch.float32` tensor shaped like `index` holding the selected
        log-probabilities.
    """
    logits = logits.to(torch.float32)
    # Pick x_i for every position: gather needs a trailing axis of size 1.
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # Log-normaliser over the last dimension, computed in a single call.
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    # log_softmax(x_i) = x_i - logsumexp(x)
    per_token_logps = selected_logits - logsumexp_values
    return per_token_logps
|
| 42 |
+
@dataclass
class UnslothORPOConfig(ORPOConfig):
    """
    Configuration class for the [`ORPOTrainer`], with Unsloth-specific defaults and extras.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Every constructor argument other than `vllm_sampling_params` and `unsloth_num_chunks` is forwarded
    unchanged to `ORPOConfig.__init__`, together with any extra `**kwargs`. The constructor additionally:

    - rejects a `learning_rate` below `1e-7` (`FloatingPointError`) or above `1` (`OverflowError`);
    - when `output_dir` is `None` and `save_strategy` / `save_steps` are still `'steps'` / `500`, sets
      `output_dir` to `'unsloth_training_checkpoints'` and `save_strategy` to `'no'`;
    - when `dataset_num_proc` is `None`, replaces it with `multiprocessing.cpu_count()`.

    Parameters:
        learning_rate (`float`, *optional*, defaults to `5e-5`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
            to use the default data collator.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. This argument is required if you want to use the default data collator.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion. This argument is required if you want to use the default data collator
            and your model is an encoder-decoder.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
            it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the default data collator.
        padding_value (`int` or `None`, *optional*, defaults to `None`):
            Padding value to use. If `None`, the padding value of the tokenizer is used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
            This argument is required if you want to use the default data collator.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
            you need to specify if the model returned by the callable is an encoder-decoder model.
        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
            string.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. `None` is replaced by the CPU count.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to `ORPOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance; not forwarded
            to `ORPOConfig`.

    """
    # Extra dataclass fields layered on top of ORPOConfig; both are also plain
    # constructor arguments and are assigned on the instance after super().__init__.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        tp_size = 0,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        max_length = 1024,
        max_prompt_length = 512,
        max_completion_length = None,
        beta = 0.1,
        disable_dropout = True,
        label_pad_token_id = -100,
        padding_value = None,
        truncation_mode = 'keep_end',
        generate_during_eval = False,
        is_encoder_decoder = None,
        model_init_kwargs = None,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Guard against learning rates that are almost certainly a mistake.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # Saving options left at their defaults with no output directory:
        # pick a directory name and switch checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Default dataset preprocessing to one worker per reported CPU.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            tp_size = tp_size,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            beta = beta,
            disable_dropout = disable_dropout,
            label_pad_token_id = label_pad_token_id,
            padding_value = padding_value,
            truncation_mode = truncation_mode,
            generate_during_eval = generate_during_eval,
            is_encoder_decoder = is_encoder_decoder,
            model_init_kwargs = model_init_kwargs,
            dataset_num_proc = dataset_num_proc,**kwargs)
        # The parent constructor does not receive these two, so set them here.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 399 |
+
|
| 400 |
+
class _UnslothORPOTrainer(Trainer):
|
| 401 |
+
r""""""
|
| 402 |
+
|
| 403 |
+
_tag_names = ["trl", "orpo"]
|
| 404 |
+
|
| 405 |
+
def __init__(
|
| 406 |
+
self,
|
| 407 |
+
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
| 408 |
+
args: Optional[ORPOConfig] = None,
|
| 409 |
+
data_collator: Optional[DataCollator] = None,
|
| 410 |
+
train_dataset: Optional[Dataset] = None,
|
| 411 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 412 |
+
processing_class: Optional[
|
| 413 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 414 |
+
] = None,
|
| 415 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 416 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 417 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 418 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 419 |
+
peft_config: Optional[dict] = None,
|
| 420 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
| 421 |
+
):
|
| 422 |
+
if args.model_init_kwargs is None:
|
| 423 |
+
model_init_kwargs = {}
|
| 424 |
+
elif not isinstance(model, str):
|
| 425 |
+
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
|
| 426 |
+
else:
|
| 427 |
+
model_init_kwargs = args.model_init_kwargs
|
| 428 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 429 |
+
if torch_dtype is not None:
|
| 430 |
+
# Convert to `torch.dtype` if an str is passed
|
| 431 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 432 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 433 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 434 |
+
raise ValueError(
|
| 435 |
+
f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 436 |
+
)
|
| 437 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 438 |
+
|
| 439 |
+
if isinstance(model, str):
|
| 440 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 441 |
+
|
| 442 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 443 |
+
# has been called in order to properly call autocast if needed.
|
| 444 |
+
self._peft_has_been_casted_to_bf16 = False
|
| 445 |
+
|
| 446 |
+
if not is_peft_available() and peft_config is not None:
|
| 447 |
+
raise ValueError(
|
| 448 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
| 449 |
+
)
|
| 450 |
+
elif is_peft_available() and peft_config is not None:
|
| 451 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
| 452 |
+
if isinstance(model, PeftModel):
|
| 453 |
+
model = model.merge_and_unload()
|
| 454 |
+
|
| 455 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 456 |
+
_support_gc_kwargs = hasattr(
|
| 457 |
+
args, "gradient_checkpointing_kwargs"
|
| 458 |
+
) and "gradient_checkpointing_kwargs" in list(
|
| 459 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 463 |
+
|
| 464 |
+
if _support_gc_kwargs:
|
| 465 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 466 |
+
|
| 467 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 468 |
+
elif getattr(args, "gradient_checkpointing", False):
|
| 469 |
+
# For backward compatibility with older versions of transformers
|
| 470 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 471 |
+
model.enable_input_require_grads()
|
| 472 |
+
else:
|
| 473 |
+
|
| 474 |
+
def make_inputs_require_grad(module, input, output):
|
| 475 |
+
output.requires_grad_(True)
|
| 476 |
+
|
| 477 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 478 |
+
|
| 479 |
+
# get peft model with the given config
|
| 480 |
+
model = model
|
| 481 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 482 |
+
peft_module_casting_to_bf16(model)
|
| 483 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 484 |
+
self._peft_has_been_casted_to_bf16 = True
|
| 485 |
+
|
| 486 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 487 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 488 |
+
# fail or completely fail.
|
| 489 |
+
elif getattr(args, "gradient_checkpointing", False):
|
| 490 |
+
# For backward compatibility with older versions of transformers
|
| 491 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 492 |
+
model.enable_input_require_grads()
|
| 493 |
+
else:
|
| 494 |
+
|
| 495 |
+
def make_inputs_require_grad(module, input, output):
|
| 496 |
+
output.requires_grad_(True)
|
| 497 |
+
|
| 498 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 499 |
+
|
| 500 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 501 |
+
raise ValueError(
|
| 502 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 503 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if model is not None:
|
| 507 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
| 508 |
+
elif args.is_encoder_decoder is None:
|
| 509 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
| 510 |
+
else:
|
| 511 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
| 512 |
+
|
| 513 |
+
if self.is_encoder_decoder:
|
| 514 |
+
self.decoder_start_token_id = model.config.decoder_start_token_id
|
| 515 |
+
self.pad_token_id = model.config.pad_token_id
|
| 516 |
+
|
| 517 |
+
if processing_class is None:
|
| 518 |
+
raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
|
| 519 |
+
if args.max_length is None:
|
| 520 |
+
warnings.warn(
|
| 521 |
+
"`max_length` is not set in the ORPOConfig's init"
|
| 522 |
+
" it will default to `512` by default, but you should do it yourself in the future.",
|
| 523 |
+
UserWarning,
|
| 524 |
+
)
|
| 525 |
+
max_length = 512
|
| 526 |
+
else:
|
| 527 |
+
max_length = args.max_length
|
| 528 |
+
if args.max_prompt_length is None:
|
| 529 |
+
warnings.warn(
|
| 530 |
+
"`max_prompt_length` is not set in the ORPOConfig's init"
|
| 531 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
| 532 |
+
UserWarning,
|
| 533 |
+
)
|
| 534 |
+
max_prompt_length = 128
|
| 535 |
+
else:
|
| 536 |
+
max_prompt_length = args.max_prompt_length
|
| 537 |
+
|
| 538 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
| 539 |
+
warnings.warn(
|
| 540 |
+
"When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
|
| 541 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
| 542 |
+
UserWarning,
|
| 543 |
+
)
|
| 544 |
+
self.max_completion_length = 128
|
| 545 |
+
else:
|
| 546 |
+
self.max_completion_length = args.max_completion_length
|
| 547 |
+
|
| 548 |
+
if data_collator is None:
|
| 549 |
+
data_collator = DPODataCollatorWithPadding(
|
| 550 |
+
pad_token_id=processing_class.pad_token_id,
|
| 551 |
+
label_pad_token_id=args.label_pad_token_id,
|
| 552 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
if args.remove_unused_columns:
|
| 556 |
+
args.remove_unused_columns = False
|
| 557 |
+
# warn users
|
| 558 |
+
warnings.warn(
|
| 559 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
| 560 |
+
" we have set it for you, but you should do it yourself in the future.",
|
| 561 |
+
UserWarning,
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
self.use_dpo_data_collator = True
|
| 565 |
+
else:
|
| 566 |
+
self.use_dpo_data_collator = False
|
| 567 |
+
|
| 568 |
+
# Disable dropout in the model and reference model
|
| 569 |
+
if args.disable_dropout:
|
| 570 |
+
disable_dropout_in_model(model)
|
| 571 |
+
|
| 572 |
+
self.max_length = max_length
|
| 573 |
+
self.generate_during_eval = args.generate_during_eval
|
| 574 |
+
self.label_pad_token_id = args.label_pad_token_id
|
| 575 |
+
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
| 576 |
+
self.max_prompt_length = max_prompt_length
|
| 577 |
+
self.truncation_mode = args.truncation_mode
|
| 578 |
+
self.processing_class = processing_class
|
| 579 |
+
|
| 580 |
+
self.beta = args.beta
|
| 581 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 582 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 583 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 584 |
+
warnings.warn(
|
| 585 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 586 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 587 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 588 |
+
"loss.",
|
| 589 |
+
UserWarning,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 593 |
+
|
| 594 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 595 |
+
# input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
|
| 596 |
+
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
| 597 |
+
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
| 598 |
+
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
| 599 |
+
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
| 600 |
+
# that the warning has already been issued.
|
| 601 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 602 |
+
|
| 603 |
+
# Compute that only on the main process for faster data processing.
|
| 604 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
| 605 |
+
with PartialState().local_main_process_first():
|
| 606 |
+
# Extract the prompt if needed, and apply the chat template if needed
|
| 607 |
+
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 608 |
+
train_dataset = train_dataset.map(
|
| 609 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
| 610 |
+
)
|
| 611 |
+
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 612 |
+
if eval_dataset is not None:
|
| 613 |
+
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 614 |
+
eval_dataset = eval_dataset.map(
|
| 615 |
+
maybe_apply_chat_template,
|
| 616 |
+
fn_kwargs={"tokenizer": processing_class},
|
| 617 |
+
num_proc=args.dataset_num_proc,
|
| 618 |
+
)
|
| 619 |
+
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 620 |
+
|
| 621 |
+
super().__init__(
|
| 622 |
+
model=model,
|
| 623 |
+
args=args,
|
| 624 |
+
data_collator=data_collator,
|
| 625 |
+
train_dataset=train_dataset,
|
| 626 |
+
eval_dataset=eval_dataset,
|
| 627 |
+
processing_class=processing_class,
|
| 628 |
+
model_init=model_init,
|
| 629 |
+
compute_metrics=compute_metrics,
|
| 630 |
+
callbacks=callbacks,
|
| 631 |
+
optimizers=optimizers,
|
| 632 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
# Add tags for models that have been loaded with the correct transformers version
|
| 636 |
+
if hasattr(self.model, "add_model_tags"):
|
| 637 |
+
self.model.add_model_tags(self._tag_names)
|
| 638 |
+
|
| 639 |
+
if not hasattr(self, "accelerator"):
|
| 640 |
+
raise AttributeError(
|
| 641 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
    """Initialise `model` with DeepSpeed and put the resulting engine in eval mode.

    Adapted from accelerate:
    https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473

    The DeepSpeed config held by the accelerator's plugin is deep-copied first, so the
    edits made here do not modify the plugin's own config.

    Args:
        model: The model to hand to `deepspeed.initialize`.

    Returns:
        The first object returned by `deepspeed.initialize`, after `eval()` was called on it.
    """
    ds_config = deepcopy(self.accelerator.state.deepspeed_plugin.deepspeed_config)

    if model is not None and hasattr(model, "config"):
        # Prefer the largest entry of `hidden_sizes` when the config defines a non-empty one.
        if getattr(model.config, "hidden_sizes", None):
            hidden_size = max(model.config.hidden_sizes)
        else:
            hidden_size = getattr(model.config, "hidden_size", None)

        if hidden_size is not None and ds_config["zero_optimization"]["stage"] == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like:
            # `Invalidate trace cache @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            # NOTE(review): these dotted keys are stored at the top level of the copied config,
            # not inside the nested "zero_optimization" section - confirm this is intended.
            ds_config["zero_optimization.reduce_bucket_size"] = hidden_size * hidden_size
            ds_config["zero_optimization.stage3_param_persistence_threshold"] = 10 * hidden_size
            ds_config["zero_optimization.stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each
    # device with ZeRO disabled (stage 0).
    if ds_config["zero_optimization"]["stage"] != 3:
        ds_config["zero_optimization"]["stage"] = 0

    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()
    return engine
|
| 674 |
+
|
| 675 |
+
def build_tokenized_answer(self, prompt, answer):
    """
    Tokenize `prompt + answer` once and split the result into a prompt part and an answer part.

    Llama-style tokenizers do NOT satisfy `enc(a + b) = enc(a) + enc(b)`.
    They do ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`, except that the last
    prompt token may be merged with the start of the answer; that case is handled below.
    Reference:
        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Args:
        prompt: Prompt text.
        answer: Answer text that is appended directly to `prompt`.

    Returns:
        A dict with `prompt_input_ids`, `prompt_attention_mask` (prompt part) and
        `input_ids`, `attention_mask` (answer part), all sliced from the tokenization
        of `prompt + answer`.

    Raises:
        ValueError: If the prompt alone tokenizes to more tokens than `prompt + answer`,
            or if the prompt ids and prompt attention mask end up with different lengths.
    """
    full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    full_input_ids = full_tokenized["input_ids"]
    full_attention_mask = full_tokenized["attention_mask"]

    # `enc(a) + enc(a + b)[len(enc(a)):]` has the same length as `enc(a + b)` exactly when
    # the prompt is not longer than the full sequence, so compare the lengths directly.
    if len(prompt_input_ids) > len(full_input_ids):
        raise ValueError(
            "The tokenized prompt is longer than the tokenized prompt + answer; "
            "cannot split the sequence into prompt and answer tokens."
        )

    # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This could result
    # on the last token from the prompt being different when tokenized on its own
    # vs when done as prompt+answer.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If the tokenized prompt differs from the head of prompt+answer, the last prompt
    # token has changed due to merging, so that token is assigned to the answer.
    if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
    prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=full_input_ids[response_token_ids_start_idx:],
        attention_mask=full_attention_mask[response_token_ids_start_idx:],
    )
|
| 724 |
+
|
| 725 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
    """Tokenize one `prompt` / `chosen` / `rejected` row of an ORPO dataset.

    No tensors are created here (except the optional decoder input ids); this step only
    tokenizes and truncates. The prompt is truncated first; if prompt + response is still
    longer than `self.max_length`, the chosen/rejected responses are truncated as well.

    For decoder-only models, labels are also built for the chosen/rejected sequences: they
    have the length of prompt + response, with `self.label_pad_token_id` on the prompt positions.

    Args:
        feature: Mapping with string entries `"prompt"`, `"chosen"` and `"rejected"`.
        model: Optional model; in the encoder-decoder path, if it provides
            `prepare_decoder_input_ids_from_labels`, decoder input ids are added to the row.

    Returns:
        A dict with the tokenized fields of the row.

    Raises:
        ValueError: In the decoder-only path, if prompt/chosen/rejected is not a `str`, if
            the two tokenized prompts differ by more than one token or one position in
            length, or if `self.truncation_mode` is unknown when truncation is needed.
    """
    batch = {}
    prompt, chosen, rejected = feature["prompt"], feature["chosen"], feature["rejected"]

    if self.is_encoder_decoder:
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["rejected_labels"])
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["chosen_labels"])
            )
    else:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337
        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_encoding = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{name}": value for name, value in prompt_encoding.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens
        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
        prompt_tokens = {name: value[:prompt_len_input_ids] for name, value in prompt_tokens.items()}

        # Make sure prompts only have one different token at most
        # and length only differs by 1 at most
        num_diff_tokens = 0
        for chosen_id, rejected_id in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"]):
            if chosen_id != rejected_id:
                num_diff_tokens += 1
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt. Avoid adding if it's already there
        prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
            self.processing_class.bos_token_id,
            prompt_len_input_ids,
            prompt_tokens,
            chosen_prompt_len_input_ids,
            chosen_tokens,
            rejected_prompt_len_input_ids,
            rejected_tokens,
        )

        # add EOS token to end of answer. Avoid adding if it's already there
        chosen_tokens, rejected_tokens = add_eos_token_if_needed(
            self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
        )

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for tokens in (chosen_tokens, rejected_tokens, prompt_tokens):
            if len(tokens["prompt_input_ids"]) + longer_response_length <= self.max_length:
                continue
            if self.truncation_mode == "keep_start":
                window = slice(None, self.max_prompt_length)
            elif self.truncation_mode == "keep_end":
                window = slice(-self.max_prompt_length, None)
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
            for key in ("prompt_input_ids", "prompt_attention_mask"):
                tokens[key] = tokens[key][window]

        # if that's still too long, truncate the response
        for tokens in (chosen_tokens, rejected_tokens):
            if len(tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for key in ("input_ids", "attention_mask"):
                    tokens[key] = tokens[key][: self.max_length - self.max_prompt_length]

        # Create labels: prompt positions are replaced by the label pad token id
        sequence_groups = []
        for prefix, tokens in (("chosen_", chosen_tokens), ("rejected_", rejected_tokens)):
            sequence = {}
            for key in ("input_ids", "attention_mask"):
                sequence[key] = tokens[f"prompt_{key}"] + tokens[key]
            prompt_length = len(tokens["prompt_input_ids"])
            sequence["labels"] = [self.label_pad_token_id] * prompt_length + sequence["input_ids"][prompt_length:]
            sequence_groups.append((prefix, sequence))
        sequence_groups.append(("", prompt_tokens))

        for prefix, group in sequence_groups:
            for name, value in group.items():
                if name != "token_type_ids":
                    batch[f"{prefix}{name}"] = value

    if is_torch_xla_available():
        # Pad the sequences to global max_length to avoid TorchXLA recompilation
        for key, sequence in batch.items():
            if "labels" in key or self.is_encoder_decoder:
                pad_value = self.label_pad_token_id
            elif key.endswith("_input_ids"):
                pad_value = self.padding_value
            elif key.endswith("_attention_mask"):
                pad_value = 0
            batch[key] = sequence + [pad_value] * (self.max_length - len(sequence))

    return batch
|
| 880 |
+
|
| 881 |
+
@staticmethod
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> dict[str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of a batch along the batch dimension.

    Every tensor whose key starts with `chosen` is padded to a common length and stored
    under the same key with `chosen` replaced by `concatenated`; the matching `rejected`
    tensor is padded the same way and appended below it (dim 0).

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids'
            (or 'chosen_labels' and 'rejected_labels' for encoder-decoder models), which are
            tensors of shape (batch_size, sequence_length).
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: The label pad token id.
        padding_value: The padding value to use for the concatenated inputs_ids.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary with the concatenated tensors, e.g. under 'concatenated_input_ids'.
    """
    length_source = "labels" if is_encoder_decoder else "input_ids"
    max_length = max(
        batch[f"chosen_{length_source}"].shape[1],
        batch[f"rejected_{length_source}"].shape[1],
    )

    concatenated_batch = {}

    # First pass: the padded chosen tensors seed every concatenated entry.
    for key, value in batch.items():
        if not (key.startswith("chosen") and isinstance(value, torch.Tensor)):
            continue
        if "labels" in key or is_encoder_decoder:
            pad_value = label_pad_token_id
        elif key.endswith("_input_ids"):
            pad_value = padding_value
        elif key.endswith("_attention_mask"):
            pad_value = 0
        target_key = key.replace("chosen", "concatenated")
        concatenated_batch[target_key] = pad_to_length(value, max_length, pad_value=pad_value)

    # Second pass: append the padded rejected tensors and move the result to `device`.
    for key, value in batch.items():
        if not (key.startswith("rejected") and isinstance(value, torch.Tensor)):
            continue
        if "labels" in key or is_encoder_decoder:
            pad_value = label_pad_token_id
        elif key.endswith("_input_ids"):
            pad_value = padding_value
        elif key.endswith("_attention_mask"):
            pad_value = 0
        target_key = key.replace("rejected", "concatenated")
        chosen_part = concatenated_batch[target_key]
        rejected_part = pad_to_length(value, max_length, pad_value=pad_value)
        concatenated_batch[target_key] = torch.cat((chosen_part, rejected_part), dim=0).to(device=device)

    if is_encoder_decoder:
        # The encoder input is the prompt, repeated once for chosen and once for rejected.
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch
|
| 942 |
+
|
| 943 |
+
def odds_ratio_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute ORPO's odds ratio (OR) loss for a batch of policy model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of five tensors, in this order:
            losses: `beta * logsigmoid(log_odds)` for each example in the batch.
            chosen_rewards: `beta * policy_chosen_logps`, detached, on the accelerator device.
            rejected_rewards: `beta * policy_rejected_logps`, detached, on the accelerator device.
            The batch mean of `logsigmoid(log_odds)`, for logging purposes.
            The batch mean of `log_odds` (log odds of chosen over rejected), for logging purposes.
    """
    # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using
    # log identities and exp(log(P(y|x)) = P(y|x):
    # log odds(y) = log P(y) - log(1 - P(y)), hence the `log1p(-exp(logp))` terms.
    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
        torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
    )
    ratio = F.logsigmoid(log_odds)
    losses = self.beta * ratio

    # Rewards are only used for metrics, so they are detached from the graph.
    chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
    rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()

    return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
|
| 973 |
+
|
| 974 |
+
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Return the per-sequence log probability of `labels` under `logits`.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Positions equal to
            `label_pad_token_id` are ignored. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per (non-masked) token.
            Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: The label pad token id.
        is_encoder_decoder: Whether the model is an encoder-decoder model. When False, labels
            and logits are shifted by one position so that position `t` predicts token `t + 1`.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the
        given labels under the given logits.

    Raises:
        ValueError: If `logits.shape[:-1]` differs from `labels.shape`.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if not is_encoder_decoder:
        # Drop the first label and the last logit to align predictions with their targets.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

    keep = labels != label_pad_token_id

    # Masked positions get a dummy token id (0); their contribution is zeroed out by `keep`.
    gather_labels = torch.where(keep, labels, 0)
    token_logps = selective_log_softmax(logits, gather_labels)

    summed_logps = (token_logps * keep).sum(-1)
    if not average_log_prob:
        return summed_logps
    return summed_logps / keep.sum(-1)
|
| 1011 |
+
|
| 1012 |
+
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, ...]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Args:
        model: The model to run; it is called with the concatenated input ids, the attention
            mask, `use_cache=False`, and (for encoder-decoder models) `decoder_input_ids`.
        batch: Batch containing the chosen/rejected tensors, including `chosen_labels`.

    Returns:
        `(chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)`,
        followed by `outputs.aux_loss` as a sixth element when `self.aux_loss_enabled` is set.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    # The chosen examples occupy the first `len_chosen` rows of every concatenated tensor.
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
        }
        if self.is_encoder_decoder
        else {}
    )

    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    def cross_entropy_loss(logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Masked positions carry `self.label_pad_token_id`, so that exact id must be the
        # ignored index; the default (-100) only matches when the pad id is left at -100.
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        # Flatten the tokens
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    if self.is_encoder_decoder:
        labels = concatenated_batch["concatenated_labels"].clone()
    else:
        labels = concatenated_batch["concatenated_input_ids"].clone()
        attention_mask = concatenated_batch["concatenated_attention_mask"]
        # Positions outside the attention mask must not contribute to the NLL loss.
        labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
    # orpo chosen nll loss is computed over the full prompt and response
    chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=True,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    if not self.is_encoder_decoder:
        # Drop the last position, which has no next token to predict.
        chosen_logits = all_logits[:len_chosen, :-1, :]
        rejected_logits = all_logits[len_chosen:, :-1, :]
    else:
        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
|
| 1092 |
+
|
| 1093 |
+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Run the concatenated forward pass and derive the ORPO loss and its logging metrics.

    Args:
        model: Model forwarded to `self.concatenated_forward`.
        batch: Collated batch forwarded to `self.concatenated_forward`.
        train_eval: `"train"` or `"eval"`; in eval mode every metric key gets an `eval_` prefix.

    Returns:
        A `(loss, metrics)` tuple. `loss` is the NLL term minus the mean odds-ratio term, plus
        `self.aux_loss_coef * aux_loss` when `self.aux_loss_enabled` is set. `metrics` maps
        metric names to Python scalars (each value has been reduced with `.item()`).
    """
    forward_output = self.concatenated_forward(model, batch)
    chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss = forward_output[:5]
    # The sixth element is only read when the auxiliary loss is enabled.
    aux_loss = forward_output[5] if self.aux_loss_enabled else None

    ratio_losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
        chosen_logps, rejected_logps
    )
    # full ORPO loss
    loss = nll_loss - ratio_losses.mean()

    key_prefix = "eval_" if train_eval == "eval" else ""
    gather = self.accelerator.gather_for_metrics

    # (metric name, local tensor, detach after gathering). The tuple order is the order in
    # which the gathers are issued, so it must stay exactly as listed.
    reported = (
        ("rewards/chosen", chosen_rewards, False),
        ("rewards/rejected", rejected_rewards, False),
        ("rewards/accuracies", (chosen_rewards > rejected_rewards).float(), False),
        ("rewards/margins", chosen_rewards - rejected_rewards, False),
        ("logps/rejected", rejected_logps, True),
        ("logps/chosen", chosen_logps, True),
        ("logits/rejected", rejected_logits, True),
        ("logits/chosen", chosen_logits, True),
        ("nll_loss", nll_loss, True),
        ("log_odds_ratio", log_odds_ratio, False),
        ("log_odds_chosen", log_odds_chosen, False),
    )

    metrics = {}
    for name, local_value, detach_first in reported:
        gathered = gather(local_value)
        if detach_first:
            gathered = gathered.detach()
        metrics[f"{key_prefix}{name}"] = gathered.mean()

    if is_torch_xla_available():
        xm.mark_step()  # needed because .item() calls
    metrics = {name: value.item() for name, value in metrics.items()}

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
|
| 1145 |
+
|
| 1146 |
+
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the training loss for one batch and buffer its metrics for logging.

    Args:
        model: Model forwarded to `self.get_batch_loss_metrics`.
        inputs: Batch forwarded to `self.get_batch_loss_metrics`.
        return_outputs: When truthy, return `(loss, metrics)` instead of only the loss.
        num_items_in_batch: Accepted but not used by this implementation.

    Returns:
        The loss moved to `self.args.device`, or `(loss, metrics)` when `return_outputs` is set.
    """
    if self._peft_has_been_casted_to_bf16:
        loss_context = amp.autocast("cuda")
    else:
        loss_context = nullcontext()

    with loss_context:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # The `Trainer` class accumulates the loss on `args.device`, so move ours there first.
    loss = loss.to(self.args.device)

    # force log the metrics
    self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
|
| 1167 |
+
|
| 1168 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> list[str]:
    """Sample completions from `model` for the prompts in `batch` and decode them to text.

    Only the model passed in is queried; no reference model is involved.

    Args:
        model: Object exposing `generate(...)`.
        batch: Must contain `"prompt_input_ids"` and `"prompt_attention_mask"`.

    Returns:
        The result of `self.processing_class.batch_decode` on the generated ids, which are
        first padded to `self.max_length` with the processing class's pad token id.
    """
    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch cuda amp context manager as some hidden states are silently casted to full precision.
    generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with generate_context_manager:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
        )

    policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
    policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

    return policy_output_decoded
|
| 1188 |
+
|
| 1189 |
+
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Evaluate one batch without gradients and buffer its metrics under the `"eval"` split.

    Args:
        model: Model forwarded to `self.get_batch_loss_metrics`.
        inputs: Batch forwarded to `self.get_batch_loss_metrics`.
        prediction_loss_only: When truthy, only the loss is returned.
        ignore_keys: Metric keys to leave out of the returned logits. Defaults to
            `model.config.keys_to_ignore_at_inference` when the model has a config, else `[]`.

    Returns:
        `(loss, None, None)` when `prediction_loss_only` is set; otherwise `(loss, logits, labels)`
        where `logits` is a 1-D tensor holding the mean chosen/rejected logits that were not
        ignored and `labels` is a zero tensor of the same length.
    """
    if not self.use_dpo_data_collator:
        warnings.warn(
            "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
            "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
        )
    if ignore_keys is None:
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # force log the metrics
    self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model
    logits_dict = {
        "eval_logits/chosen": metrics["eval_logits/chosen"],
        "eval_logits/rejected": metrics["eval_logits/rejected"],
    }
    # `get_batch_loss_metrics` has already reduced every metric to a Python scalar with
    # `.item()`, so tensor-only methods such as `.unsqueeze()` would raise AttributeError.
    # `float()` accepts both plain scalars and 0-dim tensors.
    logits = torch.tensor(
        [float(value) for key, value in logits_dict.items() if key not in ignore_keys],
        device=self.accelerator.device,
    )
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
|
| 1228 |
+
|
| 1229 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append each metric value to the per-split buffer `self._stored_metrics[train_eval]`."""
    for name in metrics:
        self._stored_metrics[train_eval][name].append(metrics[name])
|
| 1232 |
+
|
| 1233 |
+
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When `self.generate_during_eval` is set, one randomly drawn batch is sampled from the
    policy and logged as a "game_log" table to the reporters named in `self.args.report_to`
    before the parent evaluation loop runs.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples
        num_samples = len(dataloader.dataset)
        # `random.sample` raises ValueError when `k` exceeds the population size, so cap the
        # batch at the dataset size for evaluation sets smaller than `eval_batch_size`.
        sample_size = min(self.args.eval_batch_size, num_samples)
        random_indices = random.sample(range(num_samples), k=sample_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded = self.generate_from_model(self.model, random_batch)

        table = pd.DataFrame(
            columns=["Prompt", "Policy"],
            data=[
                # Drop the leading prompt-length characters so only the continuation is logged.
                [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
|
| 1282 |
+
|
| 1283 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Merge the averaged buffered metrics into `logs`, then hand off to the parent `log`.

    Args:
        logs (`dict[str, float]`):
            The values to log; updated in place with one averaged entry per buffered metric.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training. Only forwarded on transformers >= 4.47.0.dev0.
    """
    # logs either has 'loss' or 'eval_loss'
    split = "eval" if "loss" not in logs else "train"

    # Add averaged stored metrics to logs
    buffered = self._stored_metrics[split]
    for name, values in buffered.items():
        logs[name] = torch.tensor(values).mean().item()
    # Reset the buffer for this split now that it has been reported.
    del self._stored_metrics[split]

    accepts_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
    if not accepts_start_time:  # transformers<=4.46
        return super().log(logs)
    return super().log(logs, start_time)
|
| 1304 |
+
|
| 1305 |
+
def _shift_right(self, input_ids):
    """Shift `input_ids` one position right along the last axis, starting with the decoder start token.

    The first position is set to `self.decoder_start_token_id`, the last input token is dropped,
    and any `-100` entries in the result are replaced by `self.pad_token_id`.

    Raises:
        ValueError: If `self.decoder_start_token_id` or `self.pad_token_id` is `None`.
    """
    start_token_id = self.decoder_start_token_id
    if start_token_id is None:
        raise ValueError(
            "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
        )

    # shift inputs to the right
    if not is_torch_fx_proxy(input_ids):
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_token_id
    else:
        # Item assignment is not supported natively for proxies.
        first_column = torch.full(input_ids.shape[:-1] + (1,), start_token_id)
        shifted = torch.cat([first_column, input_ids[..., :-1]], dim=-1)

    pad_id = self.pad_token_id
    if pad_id is None:
        raise ValueError("model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted.masked_fill_(shifted == -100, pad_id)

    return shifted
|
| 1327 |
+
|
| 1328 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing happens on
    processes for which `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The caller's list is never modified.
    """
    if not self.is_world_process_zero():
        return

    # A `_name_or_path` that is an existing local directory is not reported as the base model.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    tags = tags or []
    # Work on a private list so the `append` below cannot mutate the caller's argument.
    tags = [tags] if isinstance(tags, str) else list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
        @article{hong2024orpo,
            title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
            author = {Jiwoo Hong and Noah Lee and James Thorne},
            year = 2024,
            eprint = {arXiv:2403.07691}
        }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="ORPO",
        trainer_citation=citation,
        paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
        paper_id="2403.07691",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1383 |
+
class UnslothORPOTrainer(_UnslothORPOTrainer):
    """

    Initialize ORPOTrainer.

    This subclass reconciles the precision flags on `args` with the model's dtype, adjusts
    evaluation settings, forces right padding on the processing class and picks a compatible
    data collator before delegating to the parent constructor.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        args (`ORPOConfig`):
            The ORPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothORPOConfig()

        # ---- Training precision: reconcile the fp16/bf16 flags with the model dtype ----
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: follow the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults ----
        # NOTE(review): this reads `eval_dataset` from `args`, not from the `eval_dataset`
        # parameter, so it only fires if the config object carries such an attribute.
        # Confirm the intent before changing it; behaviour is kept as-is here.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size still at 8 is lowered to the training batch size when that is smaller.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps

        # ---- Full-eval precision: keep it consistent with the training precision ----
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Either metrics hook being supplied sets the UNSLOTH_RETURN_LOGITS flag.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Fill in `args.max_seq_length` from the model when the config has the field but left it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if args.max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Force right padding on the processing class (and its wrapped tokenizer) ----
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        tokenizer_like = processing_class

        # ---- Pick a data collator that matches the dataset columns ----
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        uses_vision_collator = isinstance(data_collator, UnslothVisionDataCollator)
        if not uses_vision_collator:
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(tokenizer_like, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(tokenizer_like)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not uses_vision_collator:
            # A processing class without `pad` but with `.tokenizer`: collate with the inner tokenizer.
            if not hasattr(tokenizer_like, 'pad') and hasattr(tokenizer_like, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(tokenizer_like.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(tokenizer_like.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('orpo_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,**kwargs)

        # Remove any NEFTune hook handle left on the instance by the parent constructor, then
        # store the noise alpha directly on the input embedding module instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1542 |
+
|
| 1543 |
+
pass
|
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
ADDED
|
@@ -0,0 +1,1269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options passed as ``options=`` to the ``torch.compile`` decorator used in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Return the log-softmax of `logits` (last dim) evaluated only at `index`.

    Args:
        logits: Tensor of scores; it is upcast to float32 before any arithmetic.
        index: Integer tensor with one fewer dimension than `logits`; each entry
            selects one position along the last dimension of `logits`.

    Returns:
        Tensor shaped like `index` holding ``log_softmax(logits)[..., index]``.
    """
    logits_fp32 = logits.to(torch.float32)
    # log_softmax(x)_i = x_i - logsumexp(x): gather the selected raw scores and
    # subtract the normaliser instead of materialising the full log-softmax.
    picked = logits_fp32.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    normaliser = logits_fp32.logsumexp(dim = -1)
    return picked - normaliser
|
| 42 |
+
def vLLMSamplingParams(**kwargs):
    """Build a vLLM ``SamplingParams`` and remember the kwargs used to build it.

    Args:
        **kwargs: Forwarded unchanged to ``vllm.SamplingParams``.

    Returns:
        The ``SamplingParams`` instance, with the original keyword arguments
        stored on it as ``_set_kwargs`` so callers can re-read exactly what was
        set explicitly.
    """
    # Imported inside the function, so vLLM is only required when this is called.
    from vllm import SamplingParams

    params = SamplingParams(**kwargs)
    params._set_kwargs = kwargs
    return params
|
| 47 |
+
@dataclass
class UnslothOnlineDPOConfig(OnlineDPOConfig):
    """
    Configuration class for the [`OnlineDPOTrainer`], with Unsloth-specific defaults and extra fields.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    The constructor validates `learning_rate`, fills in `output_dir` / `dataset_num_proc` when they are left
    unset, forwards everything else to `OnlineDPOConfig`, and finally stores the two Unsloth-only fields.

    Parameters:
        learning_rate (`float`, *optional*, defaults to `5e-05`):
            Initial learning rate for [`AdamW`] optimizer. Must lie in `[1e-7, 1]`, otherwise the constructor
            raises.
        reward_model_path (`str` or `None`, *optional*, defaults to `None`):
            Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
        judge (`str` or `None`, *optional*, defaults to `None`):
            Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
        max_new_tokens (`int`, *optional*, defaults to `64`):
            Maximum number of tokens to generate per completion.
        max_length (`int`, *optional*, defaults to `512`):
            Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
            sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the
            completion as possible.
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
            Penalty applied to the score when the model fails to generate an EOS token. This is useful to
            encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty
            must be a positive value.
        beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ
            in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β
            is selected for each new epoch and the last β is used for the rest of the epochs. This class does not
            name `beta` explicitly; it reaches the parent through `**kwargs`.
        loss_type (`str`, *optional*, defaults to `"sigmoid"`):
            Type of loss to use. Possible values are:

                - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
                - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.

        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When `None`, the CPU count is used.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.
        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for
            generation, improving generation speed. However, disabling this option allows training models that
            exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient.

    Raises:
        FloatingPointError: If `learning_rate` is below `1e-7`.
        OverflowError: If `learning_rate` is above `1`.
    """

    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        tp_size=0,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        evaluation_strategy=None,
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=False,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        dispatch_batches=None,
        split_batches=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        eval_use_gather_object=False,
        average_tokens_across_devices=False,
        reward_model_path=None,
        judge=None,
        max_new_tokens=64,
        max_length=512,
        temperature=0.9,
        missing_eos_penalty=None,
        loss_type='sigmoid',
        dataset_num_proc=None,
        disable_dropout=True,
        use_vllm=False,
        ds3_gather_for_generation=True,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] up front.
        if learning_rate < 1e-7:
            raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1:
            raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')

        # With no output_dir and the save settings left at their defaults, fall back
        # to a fixed directory name and switch checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        # Default the dataset worker count to the number of CPUs.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            tp_size=tp_size,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            evaluation_strategy=evaluation_strategy,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            dispatch_batches=dispatch_batches,
            split_batches=split_batches,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            reward_model_path=reward_model_path,
            judge=judge,
            max_new_tokens=max_new_tokens,
            max_length=max_length,
            temperature=temperature,
            missing_eos_penalty=missing_eos_penalty,
            loss_type=loss_type,
            dataset_num_proc=dataset_num_proc,
            disable_dropout=disable_dropout,
            use_vllm=use_vllm,
            ds3_gather_for_generation=ds3_gather_for_generation,
            **kwargs,
        )
        # The Unsloth-only fields are not forwarded to the parent; store them here.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 409 |
+
|
| 410 |
+
class _UnslothOnlineDPOTrainer(Trainer):
|
| 411 |
+
r""""""
|
| 412 |
+
|
| 413 |
+
_tag_names = ["trl", "online-dpo"]
|
| 414 |
+
|
| 415 |
+
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module],
    ref_model: Union[PreTrainedModel, nn.Module, None] = None,
    reward_model: Union[PreTrainedModel, nn.Module, None] = None,
    judge: Optional[BasePairwiseJudge] = None,
    args: Optional[OnlineDPOConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
    """Validate the inputs, set up generation, and initialise the base `Trainer`.

    Args:
        model: Policy model to train. If it exposes a `vllm_engine` attribute,
            `args.use_vllm` is switched on automatically.
        ref_model: Optional reference model; must not be the same object as `model`.
            When given it is put in eval mode.
        reward_model: Optional reward model, put in eval mode. Takes precedence
            over `judge` when both are passed.
        judge: Optional pairwise judge, used only when `reward_model` is `None`.
        args: Training configuration. Required.
        data_collator: Collator for batches; defaults to a
            `DPODataCollatorWithPadding` built from `processing_class.pad_token_id`.
        train_dataset: Forwarded to the base `Trainer`.
        eval_dataset: Forwarded to the base `Trainer`.
        processing_class: Tokenizer / processor for the policy model. Required.
        reward_processing_class: Stored as `self.reward_processing_class`.
        peft_config: Accepted for interface compatibility; this method does not use it.
        compute_metrics: Forwarded to the base `Trainer`.
        callbacks: Forwarded to the base `Trainer`.
        optimizers: Forwarded to the base `Trainer`.
        preprocess_logits_for_metrics: Forwarded to the base `Trainer`.

    Raises:
        ValueError: If `ref_model is model`, if neither `reward_model` nor `judge`
            is given, if `args` or `processing_class` is `None`, or if
            `args.missing_eos_penalty` is set together with a `judge`.
    """
    # A model carrying a vLLM engine implies vLLM generation, unless already enabled.
    if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
    if ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, either omit the `ref_model` argument or pass `None`."
        )

    self.ref_model = ref_model

    if reward_model is not None and judge is not None:
        warnings.warn(
            "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
            "Ignoring `judge` and using `reward_model`.",
            UserWarning,
        )
        judge = None
    elif reward_model is None and judge is None:
        raise ValueError("Either `reward_model` or `judge` must be provided.")

    self.reward_model = reward_model
    self.reward_processing_class = reward_processing_class
    self.judge = judge

    # Check `args` before its first attribute access; previously a missing `args`
    # surfaced as an AttributeError on `args.missing_eos_penalty`.
    if args is None:
        raise ValueError("`args` must be provided.")

    if args.missing_eos_penalty is not None and judge is not None:
        raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")

    # Check that the processing_class is provided
    if processing_class is None:
        raise ValueError("`processing_class` must be provided.")

    # NOTE: the PEFT-wrapping step was hard-disabled (`if False:`) in the original,
    # so `peft_config` is intentionally ignored and `model` is used as passed.

    # Disable dropout in the model and reference model
    if args.disable_dropout:
        disable_dropout_in_model(model)
        if self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

    # With no `ref_model`, `self.ref_model` stays `None` (the original's reference-copy
    # branch was hard-disabled). A user-supplied reference model is put in eval mode.
    if self.ref_model is not None:
        self.ref_model.eval()

    # Set the reward model in eval mode
    if self.reward_model is not None:
        self.reward_model.eval()

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)

    self.max_length = args.max_length

    self.stats = {
        "objective/kl": [],
        "objective/entropy": [],
        "objective/non_score_reward": [],
        "rewards/chosen": [],
        "rewards/rejected": [],
        "rewards/accuracies": [],
        "rewards/margins": [],
        "logps/chosen": [],
        "logps/rejected": [],
        "val/contain_eos_token": [],
        "beta": [],
    }
    if self.reward_model is not None:
        self.stats["objective/rlhf_reward"] = []
        self.stats["objective/scores_margin"] = []
        self.stats["objective/scores"] = []

    if args.use_vllm:
        # `SamplingParams` is not among the names imported at the top of the file
        # (as far as visible), so import it here, as `vLLMSamplingParams` does.
        from vllm import SamplingParams

        self.llm = model.vllm_engine
        self._last_loaded_step = 0
        # Extra sampling kwargs recorded by `vLLMSamplingParams`, if the user set any.
        extra_sampling_kwargs = getattr(getattr(args, 'vllm_sampling_params', None), '_set_kwargs', {})
        self.generation_config = SamplingParams(
            n=2,  # two completions per prompt
            max_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=50,
            top_p=1.0,
            detokenize=False,
            **extra_sampling_kwargs,
        )
    else:
        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=50,
            top_p=1.0,
            do_sample=True,
            use_cache=not args.gradient_checkpointing,
        )

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
    # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
    # of the input, floating-point operations will not be computed." To suppress this warning, we set the
    # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
    # that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    self._beta = args.beta

    # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
    if self.is_deepspeed_enabled:
        if self.reward_model is not None:
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
        if self.ref_model is not None:
            self.ref_model = prepare_deepspeed(
                self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
    else:
        if self.ref_model is not None:
            self.ref_model = self.ref_model.to(self.accelerator.device)
        if self.reward_model is not None:
            self.reward_model = self.reward_model.to(self.accelerator.device)
|
| 592 |
+
|
| 593 |
+
@property
|
| 594 |
+
def beta(self):
|
| 595 |
+
if isinstance(self._beta, list):
|
| 596 |
+
epoch = self.state.epoch
|
| 597 |
+
return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
|
| 598 |
+
else:
|
| 599 |
+
return self._beta
|
| 600 |
+
|
| 601 |
+
@staticmethod
|
| 602 |
+
def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
|
| 603 |
+
"""Tokenize a single row from a DPO specific dataset."""
|
| 604 |
+
if not is_encoder_decoder:
|
| 605 |
+
batch = tokenizer(feature["prompt"], add_special_tokens=False)
|
| 606 |
+
# Add BOS token to head of prompt. Avoid adding if it's already there
|
| 607 |
+
if tokenizer.bos_token_id is not None:
|
| 608 |
+
prompt_len_input_ids = len(batch["input_ids"])
|
| 609 |
+
if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
|
| 610 |
+
batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
|
| 611 |
+
batch["attention_mask"] = [1] + batch["attention_mask"]
|
| 612 |
+
else:
|
| 613 |
+
batch = tokenizer(feature["prompt"], add_special_tokens=True)
|
| 614 |
+
batch = {f"prompt_{key}": value for key, value in batch.items()}
|
| 615 |
+
return batch
|
| 616 |
+
|
| 617 |
+
# Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
|
| 618 |
+
@wraps(Trainer.get_train_dataloader)
def get_train_dataloader(self) -> DataLoader:
    """Build the training `DataLoader` without dropping unused dataset columns.

    Returns:
        The `DataLoader` over `self.train_dataset`, passed through
        `self.accelerator.prepare`.

    Raises:
        ValueError: If `self.train_dataset` is `None`.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    dataset = self.train_dataset
    loader_kwargs = {
        "batch_size": self._train_batch_size,
        "collate_fn": self.data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
    }

    # Sampler, drop_last, worker seeding and prefetching are only set for
    # non-iterable datasets.
    if not isinstance(dataset, torch.utils.data.IterableDataset):
        loader_kwargs["sampler"] = self._get_train_sampler()
        loader_kwargs["drop_last"] = self.args.dataloader_drop_last
        loader_kwargs["worker_init_fn"] = seed_worker
        loader_kwargs["prefetch_factor"] = self.args.dataloader_prefetch_factor

    return self.accelerator.prepare(DataLoader(dataset, **loader_kwargs))
|
| 640 |
+
|
| 641 |
+
# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
|
| 642 |
+
@wraps(Trainer.get_eval_dataloader)
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
    """Build (or reuse) the evaluation ``DataLoader``.

    Follows ``Trainer.get_eval_dataloader`` but skips the
    "remove_unused_columns" step.

    Args:
        eval_dataset: A dataset to evaluate on, the string key of an entry in
            ``self.eval_dataset``, or ``None`` to use ``self.eval_dataset``.

    Returns:
        The ``DataLoader`` after it has been passed through
        ``self.accelerator.prepare``.

    Raises:
        ValueError: If neither ``eval_dataset`` nor ``self.eval_dataset`` is
            available.
    """
    if eval_dataset is None and self.eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")

    args = self.args

    # With persistent workers a loader built earlier is reused instead of
    # spawning a new set of workers; eval datasets don't change during training.
    cache_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
    if (
        hasattr(self, "_eval_dataloaders")
        and cache_key in self._eval_dataloaders
        and args.dataloader_persistent_workers
    ):
        return self.accelerator.prepare(self._eval_dataloaders[cache_key])

    # Resolve which dataset to load: named entry, explicit dataset, or default.
    if isinstance(eval_dataset, str):
        dataset = self.eval_dataset[eval_dataset]
    elif eval_dataset is not None:
        dataset = eval_dataset
    else:
        dataset = self.eval_dataset

    loader_kwargs = dict(
        batch_size=args.eval_batch_size,
        collate_fn=self.data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=args.dataloader_pin_memory,
        persistent_workers=args.dataloader_persistent_workers,
    )

    if not isinstance(dataset, torch.utils.data.IterableDataset):
        loader_kwargs.update(
            sampler=self._get_eval_sampler(dataset),
            drop_last=args.dataloader_drop_last,
            prefetch_factor=args.dataloader_prefetch_factor,
        )

    # accelerator.free_memory() will destroy the references, so the
    # non-prepared loader is the one kept in the cache.
    loader = DataLoader(dataset, **loader_kwargs)
    if args.dataloader_persistent_workers:
        if not hasattr(self, "_eval_dataloaders"):
            self._eval_dataloaders = {}
        self._eval_dataloaders[cache_key] = loader

    return self.accelerator.prepare(loader)
|
| 689 |
+
|
| 690 |
+
def _generate_vllm(self, model, prompts):
    """Sample two completions per prompt through ``self.llm`` (vLLM path).

    Args:
        model: Not used by this method; kept so the signature mirrors
            ``_generate``.
        prompts: Batch of prompts. The first element decides whether the
            chat or the plain-generate entry point of ``self.llm`` is used.

    Returns:
        Tuple ``(prompt_ids, prompt_mask, completion_ids, completion_mask)``
        of tensors on ``self.accelerator.device``. Rows are ordered as all
        first completions followed by all second completions. Prompts are
        left-padded to the longest prompt; completions are right-padded to
        ``self.generation_config.max_tokens``.
    """
    eos_token_id = self.processing_class.eos_token_id
    pad_token_id = self.processing_class.pad_token_id

    # The same LoRA request is used for both the chat and the generate call.
    # NOTE(review): presumably this is how the latest adapter weights reach
    # vLLM -- confirm against `load_lora`.
    lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True)

    if is_conversational({"prompt": prompts[0]}):
        outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request=lora_request)
    else:
        outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request=lora_request)

    # Index-major order: [first completion of every prompt..., second completion of every prompt...]
    completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
    prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]

    # Create mask and pad the prompt (left) and completion (right)
    max_prompt_length = max(len(ids) for ids in prompt_ids)
    prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
    prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
    max_tokens = self.generation_config.max_tokens
    # The mask is built from the generated tokens only, i.e. before an EOS is
    # appended below, so an appended EOS is not counted as a completion token.
    completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
    # Terminate completions that stopped short without an EOS. An empty
    # completion has no last token to inspect, so it is treated as
    # "not terminated" instead of raising IndexError on ids[-1].
    completion_ids = [
        ids + [eos_token_id] if (not ids or ids[-1] != eos_token_id) and len(ids) < max_tokens else ids
        for ids in completion_ids
    ]
    completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]

    # Convert to tensors
    device = self.accelerator.device
    prompt_ids = torch.tensor(prompt_ids, device=device)
    prompt_mask = torch.tensor(prompt_mask, device=device)
    completion_ids = torch.tensor(completion_ids, device=device)
    completion_mask = torch.tensor(completion_mask, device=device)

    return prompt_ids, prompt_mask, completion_ids, completion_mask
|
| 727 |
+
|
| 728 |
+
def _generate(self, model, prompts):
    """Sample two completions per prompt with the model's own ``generate``.

    Args:
        model: Model to generate with; it is unwrapped through
            ``unwrap_model_for_generation`` before ``generate`` is called.
        prompts: Batch of raw prompts.

    Returns:
        Tuple ``(prompt_ids, prompt_mask, completion_ids, completion_mask)``.
        The prompt tensors are the collated batch repeated twice along the
        batch dimension; the completion tensors come from ``truncate_right``.
    """
    tokenizer = self.processing_class
    eos_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id

    # The chat template and tokenization are applied on the fly so that reward
    # models and policies may use different tokenizers / chat templates.
    templated = [maybe_apply_chat_template({"prompt": prompt}, tokenizer) for prompt in prompts]
    tokenized = [self.tokenize_row(row, model.config.is_encoder_decoder, tokenizer) for row in templated]
    batch = self._prepare_inputs(self.data_collator(tokenized))

    # Duplicate the batch so that every prompt gets two sampled completions.
    prompt_ids = batch["prompt_input_ids"].repeat(2, 1)
    prompt_mask = batch["prompt_attention_mask"].repeat(2, 1)

    with unwrap_model_for_generation(
        model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        generated = unwrapped_model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            generation_config=self.generation_config,
        )

    # Everything after the prompt columns is the sampled completion.
    prompt_width = prompt_ids.size(1)
    completion_ids, completion_mask = truncate_right(generated[:, prompt_width:], eos_token_id, pad_token_id)

    return prompt_ids, prompt_mask, completion_ids, completion_mask
|
| 756 |
+
|
| 757 |
+
def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
    """Return the per-token log-probabilities of the completions under ``model``.

    The prompt is truncated from the left so that prompt + completion fits in
    ``self.max_length`` where possible, the two parts are concatenated along
    dim 1, and the log-probability of every completion token is read from the
    logits at the preceding position.

    Args:
        model: Callable invoked as ``model(ids, attention_mask=mask)`` whose
            result exposes ``.logits``.
        prompt_ids: Prompt token ids; dim 1 is the sequence dimension.
        prompt_mask: Attention mask matching ``prompt_ids``.
        completion_ids: Completion token ids; dim 1 is the sequence dimension.
        completion_mask: Attention mask matching ``completion_ids``.

    Returns:
        Tensor with the same leading two dimensions as ``completion_ids``
        holding the log-probability of each completion token.
    """
    prompt_len = prompt_ids.size(1)

    # Get the number of tokens to truncate from prompt
    num_tokens_to_truncate = max(prompt_len + completion_ids.size(1) - self.max_length, 0)
    # Always keep at least one prompt token: the logits that predict the first
    # completion token live at the last prompt position. Dropping the whole
    # prompt would turn the slice below into `[-1:-1]` (empty) and break the
    # gather when the completion alone already reaches `max_length`.
    num_tokens_to_truncate = min(num_tokens_to_truncate, max(prompt_len - 1, 0))

    # Truncate left to avoid oom
    prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
    prompt_mask = prompt_mask[:, num_tokens_to_truncate:]

    # Concat the prompt and completion
    prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
    prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)

    # Get the logprobs of the completions from the model
    output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)

    # There is 1 offset, because the model predict the next token
    logits = output.logits[:, prompt_ids.size(1) - 1 : -1]

    # Take the completion tokens logprob
    logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
    return logprobs
|
| 778 |
+
|
| 779 |
+
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """Run one online-DPO optimisation step.

    Two completions are sampled per prompt, ranked by either ``self.judge``
    or the reward model, and the resulting chosen/rejected pairs are used to
    compute the loss selected by ``self.args.loss_type``. The loss is
    back-propagated here and running statistics are appended to
    ``self.stats``.

    Args:
        model: Model being trained.
        inputs: Batch; only ``inputs["prompt"]`` is read.
        num_items_in_batch: Not used by this method.

    Returns:
        The detached loss divided by ``self.args.gradient_accumulation_steps``.

    Raises:
        NotImplementedError: If ``self.args.loss_type`` is neither
            ``"sigmoid"`` nor ``"ipo"``.
    """
    model.train()

    prompts = inputs["prompt"]
    batch_size = len(prompts)

    if self.args.use_vllm:
        prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
    else:
        prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)

    contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)

    logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
    with torch.no_grad():
        if self.ref_model is not None:
            ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
        else:  # peft case: we just need to disable the adapter
            with self.model.disable_adapter():
                ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)

    # Decode the completions, and format them if the input is conversational
    device = logprobs.device
    completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
    if is_conversational({"prompt": prompts[0]}):
        completions = [[{"role": "assistant", "content": completion}] for completion in completions]

    # Get the reward from the reward model or judge
    if self.judge is not None:
        # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
        # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
        # independent of the model's chat template, we use the raw conversation data, and apply our own chat
        # template to it.
        if is_conversational({"prompt": prompts[0]}):
            environment = jinja2.Environment()
            template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
            prompts = [template.render(messages=prompt) for prompt in prompts]
            completions = [template.render(messages=completion) for completion in completions]

        # Pair the first-half completion of each prompt with its second-half counterpart.
        ranks_of_first_completion = self.judge.judge(
            prompts, list(zip(completions[:batch_size], completions[batch_size:]))
        )

        # convert ranks to a True/False mask:
        # when rank == 0, it means the first completion is the best
        # when rank == 1, it means the second completion is the best
        mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
    else:
        # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
        # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
        prompts = 2 * prompts  # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
        if is_conversational({"prompt": prompts[0]}):
            examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
            examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
            prompts = [example["prompt"] for example in examples]
            completions = [example["completion"] for example in examples]

        # Tokenize the prompts
        prompts_ids = self.reward_processing_class(
            prompts, padding=True, return_tensors="pt", padding_side="left"
        )["input_ids"].to(device)
        context_length = prompts_ids.shape[1]

        # Tokenize the completions
        completions_ids = self.reward_processing_class(
            completions, padding=True, return_tensors="pt", padding_side="right"
        )["input_ids"].to(device)

        # Concatenate the prompts and completions and get the reward
        prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
        with torch.inference_mode():
            _, scores, _ = get_reward(
                self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
            )

            # Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            # (The in-place update stays inside inference mode, where `scores` was created.)
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty

        # Split the scores in 2 (the prompts of the first half are the same as the second half)
        first_half, second_half = scores.split(batch_size)

        # Get the indices of the chosen and rejected examples
        mask = first_half >= second_half

    # mask[i] is True when the first-half completion of prompt i is preferred.
    batch_range = torch.arange(batch_size, device=device)
    chosen_indices = batch_range + (~mask * batch_size)
    rejected_indices = batch_range + (mask * batch_size)

    # Build tensor so that the first half is the chosen examples and the second half the rejected examples
    cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0)  # cr = chosen and rejected
    cr_logprobs = logprobs[cr_indices]
    cr_ref_logprobs = ref_logprobs[cr_indices]

    # mask out the padding tokens
    padding_mask = ~completion_mask.bool()
    cr_padding_mask = padding_mask[cr_indices]

    cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
    cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)

    # Split the chosen and rejected examples
    chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
    chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
    pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
    ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum

    logits = pi_logratios - ref_logratios

    if self.args.loss_type == "sigmoid":
        losses = -F.logsigmoid(self.beta * logits)
    elif self.args.loss_type == "ipo":
        losses = (logits - 1 / (2 * self.beta)) ** 2
    else:
        # Report the value that was actually dispatched on (`self.args.loss_type`).
        raise NotImplementedError(f"invalid loss type {self.args.loss_type}")

    loss = losses.mean()

    # Log everything
    if self.reward_model is not None:
        scores_margin = scores[chosen_indices] - scores[rejected_indices]
        self.stats["objective/scores_margin"].append(
            self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
        )
        self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
    self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
    self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
    self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())

    kl = logprobs - ref_logprobs
    mean_kl = kl.sum(1).mean()
    self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
    non_score_reward = (-self.beta * kl).sum(1)
    mean_non_score_reward = non_score_reward.mean()
    self.stats["objective/non_score_reward"].append(
        self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
    )
    if self.reward_model is not None:
        rlhf_reward = scores + non_score_reward
        self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
    mean_entropy = -logprobs.sum(1).mean()
    self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
    chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
    gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
    self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
    rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
    gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
    self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
    margin = gathered_chosen_rewards - gathered_rejected_rewards
    self.stats["rewards/margins"].append(margin.mean().item())
    accuracy = margin > 0
    self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
    self.stats["beta"].append(self.beta)

    if (
        self.args.torch_empty_cache_steps is not None
        and self.state.global_step % self.args.torch_empty_cache_steps == 0
    ):
        empty_cache()

    kwargs = {}

    # For LOMO optimizers you need to explicitly use the learning rate
    if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        kwargs["learning_rate"] = self._get_learning_rate()

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    if self.use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        self.accelerator.backward(loss, **kwargs)

    return loss.detach() / self.args.gradient_accumulation_steps
|
| 958 |
+
|
| 959 |
+
# Same as Trainer._maybe_log_save_evaluate but log our metrics
|
| 960 |
+
# start_time defaults to None to allow compatibility with transformers<=4.46
|
| 961 |
+
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
    """Log, evaluate and checkpoint according to the flags on ``self.control``.

    Follows ``Trainer._maybe_log_save_evaluate`` and additionally logs the
    mean of every list in ``self.stats`` (then resets those lists).

    Args:
        tr_loss: Accumulated training loss; zeroed in place after logging.
        grad_norm: Gradient norm (tensor or plain number), or ``None``.
        model: Model forwarded to ``_save_checkpoint``.
        trial: Trial object forwarded to evaluation / checkpoint helpers.
        epoch: Not used by this method.
        ignore_keys_for_eval: Forwarded to ``_evaluate``.
        start_time: Forwarded to ``self.log`` on transformers >= 4.47;
            defaults to ``None`` for compatibility with transformers<=4.46.
    """
    if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
        # all_gather + mean() to get average loss over all processes
        gathered_loss = self._nested_gather(tr_loss).mean().item()

        # reset tr_loss to zero (in place, the caller keeps the same tensor)
        tr_loss -= tr_loss

        steps_since_last_log = self.state.global_step - self._globalstep_last_logged
        logs: dict[str, float] = {"loss": round(gathered_loss / steps_since_last_log, 4)}
        if grad_norm is not None:
            if isinstance(grad_norm, torch.Tensor):
                logs["grad_norm"] = grad_norm.detach().item()
            else:
                logs["grad_norm"] = grad_norm
        logs["learning_rate"] = self._get_learning_rate()

        # Add our metrics, averaged over the steps since the last log, then reset them
        logs.update({name: sum(values) / len(values) for name, values in self.stats.items()})
        self.stats = {name: [] for name in self.stats}

        self._total_loss_scalar += gathered_loss
        self._globalstep_last_logged = self.state.global_step
        self.store_flos()

        accepts_start_time = version.parse(transformers.__version__) >= version.parse("4.47.0.dev0")
        if accepts_start_time:
            self.log(logs, start_time)
        else:  # transformers<=4.46
            self.log(logs)

    metrics = None
    if self.control.should_evaluate:
        metrics = self._evaluate(trial, ignore_keys_for_eval)
        found_new_best = self._determine_best_metric(metrics=metrics, trial=trial)

        # With the "best" strategy a checkpoint is written only on improvement.
        if self.args.save_strategy == "best":
            self.control.should_save = found_new_best

    if self.control.should_save:
        self._save_checkpoint(model, trial)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 1001 |
+
|
| 1002 |
+
# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
|
| 1003 |
+
# This can be removed once the minimum transformers version is updated to 4.47.
|
| 1004 |
+
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
|
| 1005 |
+
def _determine_best_metric(self, metrics, trial):
    """
    Decide whether the latest evaluation produced a new best metric.

    When ``args.metric_for_best_model`` is not set, nothing is compared and
    ``False`` is returned. On improvement, ``self.state.best_metric`` and
    ``self.state.best_model_checkpoint`` are updated.

    Args:
        metrics: Mapping of evaluation metric names to values.
        trial: Trial object forwarded to ``_get_output_dir``.

    Returns:
        bool: True if a new best metric was found, else False

    Raises:
        KeyError: If the (``eval_``-prefixed) metric name is missing from ``metrics``.
    """
    metric_to_check = self.args.metric_for_best_model
    if metric_to_check is None:
        return False

    if not metric_to_check.startswith("eval_"):
        metric_to_check = f"eval_{metric_to_check}"

    try:
        metric_value = metrics[metric_to_check]
    except KeyError as exc:
        raise KeyError(
            f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
            f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
        ) from exc

    # Pick the comparison and the worst-possible starting value for the chosen direction.
    if self.args.greater_is_better:
        is_improvement, worst_value = np.greater, float("-inf")
    else:
        is_improvement, worst_value = np.less, float("inf")

    if self.state.best_metric is None:
        self.state.best_metric = worst_value

    if not is_improvement(metric_value, self.state.best_metric):
        return False

    # Record the new best value and the checkpoint folder for the current step.
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    best_dir = os.path.join(self._get_output_dir(trial=trial), checkpoint_folder)
    self.state.best_metric = metric_value
    self.state.best_model_checkpoint = best_dir
    return True
|
| 1043 |
+
|
| 1044 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to ``README.md`` inside ``self.args.output_dir``.
    Nothing is done when ``self.is_world_process_zero()`` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The argument itself is never modified.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    tags = tags or []
    # Work on a private list so that appending below never mutates the caller's list.
    tags = [tags] if isinstance(tags, str) else list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    # The backslash of the BibTeX accent is escaped: in a non-raw string `\'` would collapse to `'`.
    citation = textwrap.dedent("""\
    @article{guo2024direct,
        title = {{Direct Language Model Alignment from Online AI Feedback}},
        author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
        year = 2024,
        eprint = {arXiv:2402.04792}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="Online DPO",
        trainer_citation=citation,
        paper_title="Direct Language Model Alignment from Online AI Feedback",
        paper_id="2402.04792",
    )
    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1098 |
+
class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
|
| 1099 |
+
"""
|
| 1100 |
+
|
| 1101 |
+
Initialize OnlineDPOTrainer.
|
| 1102 |
+
|
| 1103 |
+
Args:
|
| 1104 |
+
model (`transformers.PreTrainedModel` or `torch.nn.Module`):
|
| 1105 |
+
The model to train, preferably an `AutoModelForCausalLM`.
|
| 1106 |
+
ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
| 1107 |
+
The reference model to use for training. If None is specified, the reference model will be created from
|
| 1108 |
+
the model.
|
| 1109 |
+
reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
| 1110 |
+
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
| 1111 |
+
judge (`BasePairwiseJudge`):
|
| 1112 |
+
The judge to use for pairwise comparison of model completions.
|
| 1113 |
+
args (`OnlineDPOConfig`):
|
| 1114 |
+
The online DPO config arguments to use for training.
|
| 1115 |
+
data_collator (`transformers.DataCollator`):
|
| 1116 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1117 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1118 |
+
train_dataset (`datasets.Dataset`):
|
| 1119 |
+
The dataset to use for training.
|
| 1120 |
+
eval_dataset (`datasets.Dataset`):
|
| 1121 |
+
The dataset to use for evaluation.
|
| 1122 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1123 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1124 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1125 |
+
reuse the fine-tuned model.
|
| 1126 |
+
peft_config (`dict`):
|
| 1127 |
+
The peft config to use for training.
|
| 1128 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1129 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1130 |
+
a dictionary string to metric values.
|
| 1131 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 1132 |
+
The callbacks to use for training.
|
| 1133 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1134 |
+
The optimizer and scheduler to use for training.
|
| 1135 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1136 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 1137 |
+
|
| 1138 |
+
"""
|
| 1139 |
+
def __init__(
|
| 1140 |
+
self,
|
| 1141 |
+
model,
|
| 1142 |
+
ref_model = None,
|
| 1143 |
+
reward_model = None,
|
| 1144 |
+
judge = None,
|
| 1145 |
+
args = None,
|
| 1146 |
+
data_collator = None,
|
| 1147 |
+
train_dataset = None,
|
| 1148 |
+
eval_dataset = None,
|
| 1149 |
+
processing_class = None,
|
| 1150 |
+
reward_processing_class = None,
|
| 1151 |
+
peft_config = None,
|
| 1152 |
+
compute_metrics = None,
|
| 1153 |
+
callbacks = None,
|
| 1154 |
+
preprocess_logits_for_metrics = None,
|
| 1155 |
+
**kwargs
|
| 1156 |
+
):
|
| 1157 |
+
if args is None: args = UnslothOnlineDPOConfig()
|
| 1158 |
+
use_bf16 = getattr(args, 'bf16', False)
|
| 1159 |
+
use_fp16 = getattr(args, 'fp16', False)
|
| 1160 |
+
force_float32 = False
|
| 1161 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1162 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1163 |
+
force_float32 = True
|
| 1164 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1165 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1166 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1167 |
+
from unsloth_zoo.utils import _get_dtype
|
| 1168 |
+
dtype = _get_dtype(dtype)
|
| 1169 |
+
float16 = dtype == torch.float16
|
| 1170 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1171 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1172 |
+
if force_float32:
|
| 1173 |
+
args.fp16 = False
|
| 1174 |
+
args.bf16 = False
|
| 1175 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1176 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1177 |
+
args.fp16 = float16
|
| 1178 |
+
args.bf16 = not float16
|
| 1179 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1180 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1181 |
+
args.eval_strategy = 'steps'
|
| 1182 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1183 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1184 |
+
if ga_steps is not None and ga_steps > 1:
|
| 1185 |
+
from transformers import __version__ as transformers_version
|
| 1186 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
| 1187 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1188 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1189 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1190 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1191 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1192 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1193 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1194 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1195 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1196 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1197 |
+
if force_float32:
|
| 1198 |
+
args.bf16_full_eval = False
|
| 1199 |
+
args.fp16_full_eval = False
|
| 1200 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1201 |
+
args.bf16_full_eval = True
|
| 1202 |
+
args.fp16_full_eval = False
|
| 1203 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
| 1204 |
+
args.bf16_full_eval = args.bf16
|
| 1205 |
+
args.fp16_full_eval = args.fp16
|
| 1206 |
+
_output_logits = False
|
| 1207 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1208 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1209 |
+
if _output_logits:
|
| 1210 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1211 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1212 |
+
pass
|
| 1213 |
+
else:
|
| 1214 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1215 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1216 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1217 |
+
max_seq_length = model.max_seq_length
|
| 1218 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1219 |
+
if model is not None and hasattr(model, 'for_training'):
|
| 1220 |
+
model.for_training()
|
| 1221 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1222 |
+
if 'processing_class' in locals():
|
| 1223 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1224 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1225 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1226 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1227 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1228 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1229 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
| 1230 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1231 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1232 |
+
else:
|
| 1233 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1234 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1235 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1236 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1237 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1238 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1239 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1240 |
+
else:
|
| 1241 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
| 1242 |
+
other_metrics = []
|
| 1243 |
+
|
| 1244 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1245 |
+
PatchRLStatistics('online_dpo_trainer', other_metrics)
|
| 1246 |
+
|
| 1247 |
+
super().__init__(
|
| 1248 |
+
model = model,
|
| 1249 |
+
ref_model = ref_model,
|
| 1250 |
+
reward_model = reward_model,
|
| 1251 |
+
judge = judge,
|
| 1252 |
+
args = args,
|
| 1253 |
+
data_collator = data_collator,
|
| 1254 |
+
train_dataset = train_dataset,
|
| 1255 |
+
eval_dataset = eval_dataset,
|
| 1256 |
+
processing_class = processing_class,
|
| 1257 |
+
reward_processing_class = reward_processing_class,
|
| 1258 |
+
peft_config = peft_config,
|
| 1259 |
+
compute_metrics = compute_metrics,
|
| 1260 |
+
callbacks = callbacks,
|
| 1261 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
| 1262 |
+
if hasattr(self, 'neftune_hook_handle'):
|
| 1263 |
+
self.neftune_hook_handle.remove()
|
| 1264 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1265 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1266 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1267 |
+
pass
|
| 1268 |
+
|
| 1269 |
+
pass
|
unsloth_compiled_cache/UnslothPPOTrainer.py
ADDED
|
@@ -0,0 +1,1259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Option dictionary handed to `torch.compile(options=...)` by the compiled
# helper(s) defined in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic=True, fullgraph=True, options=torch_compile_options)
def selective_log_softmax(logits, index):
    """Return log-softmax values of `logits` at the positions named by `index`.

    Relies on the identity ``log_softmax(x)[i] = x[i] - logsumexp(x)`` so that
    only the selected logits and one reduction over the last dimension are
    needed. All arithmetic is done in float32.

    Args:
        logits: Tensor of scores, normalised over its last dimension.
        index: Integer tensor with one dimension fewer than `logits`; each
            entry selects a position along the last dimension of `logits`.

    Returns:
        Float32 tensor shaped like `index` holding the selected
        log-probabilities.
    """
    scores = logits.to(torch.float32)
    # Log of the softmax denominator, reduced over the last dimension.
    log_normalizer = torch.logsumexp(scores, dim=-1)
    # Pick one logit per position; the temporary trailing axis is what
    # `gather` requires and is dropped again immediately.
    chosen = torch.gather(scores, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return chosen - log_normalizer
|
| 42 |
+
@dataclass
class UnslothPPOConfig(PPOConfig):
    """
    Configuration for the [`PPOTrainer`], extended with two Unsloth-specific fields.

    Every constructor argument other than `vllm_sampling_params` and
    `unsloth_num_chunks` is forwarded unchanged to `PPOConfig.__init__`; the
    defaults in this signature are the ones used when an argument is omitted.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Constructor behaviour added by this subclass:
        * `learning_rate` below `1e-7` raises `FloatingPointError`; above `1` raises `OverflowError`.
        * When `output_dir` is `None` while `save_strategy` is `'steps'` and `save_steps` is `500`,
          `output_dir` becomes `'unsloth_training_checkpoints'` and `save_strategy` becomes `'no'`.
        * When `dataset_num_proc` is `None` it is replaced by `multiprocessing.cpu_count()`.

    Parameters:
        exp_name (`str`, *optional*, defaults to `'ppo_config'`):
            Name of this experiment.
        reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
            Path to the reward model.
        model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
        num_ppo_epochs (`int`, *optional*, defaults to `4`):
            Number of epochs to train.
        whiten_rewards (`bool`, *optional*, defaults to `False`):
            Whether to whiten the rewards.
        kl_coef (`float`, *optional*, defaults to `0.05`):
            KL coefficient.
        cliprange (`float`, *optional*, defaults to `0.2`):
            Clip range.
        vf_coef (`float`, *optional*, defaults to `0.1`):
            Value function coefficient.
        cliprange_value (`float`, *optional*, defaults to `0.2`):
            Clip range for the value function.
        gamma (`float`, *optional*, defaults to `1.0`):
            Discount factor.
        lam (`float`, *optional*, defaults to `0.95`):
            Lambda value for GAE.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not passed to `PPOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance; not passed to
            `PPOConfig`.
    """

    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        tp_size=0,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        evaluation_strategy=None,
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=False,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        dispatch_batches=None,
        split_batches=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        eval_use_gather_object=False,
        average_tokens_across_devices=False,
        dataset_num_proc=None,
        num_mini_batches=1,
        total_episodes=None,
        local_rollout_forward_batch_size=64,
        num_sample_generations=10,
        response_length=53,
        stop_token=None,
        stop_token_id=None,
        temperature=0.7,
        missing_eos_penalty=None,
        sft_model_path='EleutherAI/pythia-160m',
        world_size=None,
        num_total_batches=None,
        micro_batch_size=None,
        local_batch_size=None,
        batch_size=None,
        local_mini_batch_size=None,
        mini_batch_size=None,
        exp_name='ppo_config',
        reward_model_path='EleutherAI/pythia-160m',
        model_adapter_name=None,
        ref_adapter_name=None,
        num_ppo_epochs=4,
        whiten_rewards=False,
        kl_coef=0.05,
        cliprange=0.2,
        vf_coef=0.1,
        cliprange_value=0.2,
        gamma=1.0,
        lam=0.95,
        ds3_gather_for_generation=True,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        # Learning rates outside [1e-7, 1] are rejected before the parent is set up.
        if learning_rate < 1e-7:
            raise FloatingPointError(
                f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!'
            )
        if learning_rate > 1:
            raise OverflowError(
                f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!'
            )

        # With no output directory and the save settings still at the values
        # this signature defaults to, pick a fallback directory and switch
        # checkpoint saving off.
        save_settings_at_defaults = save_strategy == 'steps' and save_steps == 500
        if output_dir is None and save_settings_at_defaults:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        # An unspecified worker count falls back to the machine's CPU count.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            tp_size=tp_size,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            evaluation_strategy=evaluation_strategy,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            dispatch_batches=dispatch_batches,
            split_batches=split_batches,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            dataset_num_proc=dataset_num_proc,
            num_mini_batches=num_mini_batches,
            total_episodes=total_episodes,
            local_rollout_forward_batch_size=local_rollout_forward_batch_size,
            num_sample_generations=num_sample_generations,
            response_length=response_length,
            stop_token=stop_token,
            stop_token_id=stop_token_id,
            temperature=temperature,
            missing_eos_penalty=missing_eos_penalty,
            sft_model_path=sft_model_path,
            world_size=world_size,
            num_total_batches=num_total_batches,
            micro_batch_size=micro_batch_size,
            local_batch_size=local_batch_size,
            batch_size=batch_size,
            local_mini_batch_size=local_mini_batch_size,
            mini_batch_size=mini_batch_size,
            exp_name=exp_name,
            reward_model_path=reward_model_path,
            model_adapter_name=model_adapter_name,
            ref_adapter_name=ref_adapter_name,
            num_ppo_epochs=num_ppo_epochs,
            whiten_rewards=whiten_rewards,
            kl_coef=kl_coef,
            cliprange=cliprange,
            vf_coef=vf_coef,
            cliprange_value=cliprange_value,
            gamma=gamma,
            lam=lam,
            ds3_gather_for_generation=ds3_gather_for_generation,
            **kwargs,
        )

        # The Unsloth-only options are kept on the instance rather than being
        # handed to the parent constructor.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 432 |
+
|
| 433 |
+
class _UnslothPPOTrainer(Trainer):
|
| 434 |
+
_tag_names = ["trl", "ppo"]
|
| 435 |
+
|
| 436 |
+
def __init__(
    self,
    args: PPOConfig,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ],
    model: nn.Module,
    ref_model: Optional[nn.Module],
    reward_model: nn.Module,
    train_dataset: Dataset,
    value_model: Optional[nn.Module] = None,
    data_collator: Optional[DataCollatorWithPadding] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    # less commonly used
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    callbacks: Optional[list[TrainerCallback]] = None,
    peft_config: Optional["PeftConfig"] = None,
) -> None:
    """
    Set up models, batch-size bookkeeping, callbacks and dataloaders for PPO training.

    Args:
        args: PPO configuration. Several derived fields (`world_size`, `local_batch_size`,
            `micro_batch_size`, `batch_size`, `mini_batch_size`, `local_mini_batch_size`,
            `num_total_batches`, `run_name`, and `total_episodes` when it is `None`) are
            written back onto this object.
        processing_class: Object providing `eos_token_id`; also handed to the default
            `DataCollatorWithPadding` and to the callback handler.
        model: Policy model. Wrapped with `get_peft_model` when `peft_config` is given.
        ref_model: Reference model. When `None`, the policy itself is used with adapters
            disabled if it is a PEFT model; otherwise a copy is made with
            `create_reference_model`.
        reward_model: Model used to score completions.
        train_dataset: Dataset for rollouts; must support `len()`.
        value_model: Model wrapped together with the policy in `PolicyAndValueWrapper`.
        data_collator: Collator for both dataloaders; defaults to
            `DataCollatorWithPadding(processing_class)`.
        eval_dataset: Dataset wrapped by the evaluation dataloader.
        optimizers: `(optimizer, lr_scheduler)` pair stored before
            `create_optimizer_and_scheduler` is called.
        callbacks: Extra callbacks appended to the default ones.
        peft_config: Optional PEFT configuration applied to the policy.

    Raises:
        ValueError: If `model` and `ref_model` are the same object, if both `stop_token`
            and `stop_token_id` are set, if `stop_token` is not `"eos"`, or if no
            reference model is available and the policy is not a PEFT model.
        ImportError: If `peft_config` is passed but PEFT is not installed.
    """
    if ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, you must make a copy of it, or `None` if you use peft."
        )

    self.args = args
    self.processing_class = processing_class
    self.policy_model = model

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DataCollatorWithPadding(self.processing_class)

    # Handle stop token settings: update policy model's generation_config to use provided stop token.
    # `stop_token_id` is compared against None because 0 is a legitimate token id.
    if args.stop_token and args.stop_token_id is not None:
        raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
    elif args.stop_token:
        if args.stop_token == "eos":
            self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
        else:
            raise ValueError(
                f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
            )
    else:
        self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int

    # peft support
    if not is_peft_available() and peft_config is not None:
        raise ImportError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_confg, we merge and unload it first
        if isinstance(self.policy_model, PeftModel):
            self.policy_model = self.policy_model.merge_and_unload()

        # get peft model with the given config
        self.policy_model = get_peft_model(self.policy_model, peft_config)
        if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(self.policy_model)

    self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
    self.model_adapter_name = args.model_adapter_name
    self.ref_adapter_name = args.ref_adapter_name

    # Identity check rather than truthiness: container modules define `__len__`,
    # so an explicitly supplied module could otherwise be treated as missing.
    if ref_model is not None:
        self.ref_model = ref_model
    elif self.is_peft_model:
        self.ref_model = None
    else:
        self.ref_model = create_reference_model(self.policy_model)

    self.reward_model = reward_model
    self.train_dataset = train_dataset
    self.train_dataset_len = len(train_dataset)
    self.value_model = value_model
    self.data_collator = data_collator
    self.eval_dataset = eval_dataset
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

    #########
    # calculate various batch sizes
    #########
    if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
        args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    self.accelerator = accelerator
    args.world_size = accelerator.num_processes
    args.local_batch_size = (
        args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
    )
    args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
    args.batch_size = int(args.local_batch_size * args.world_size)
    args.mini_batch_size = exact_div(
        args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
    )
    args.local_mini_batch_size = exact_div(
        args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
    )
    if args.whiten_rewards:
        assert (
            args.local_mini_batch_size >= 8
        ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
    # `per_rank_rollout_batch_size` is our `args.local_batch_size`
    # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
    args.num_total_batches = math.ceil(
        args.total_episodes / args.batch_size
    )  # we may train for more than `total_episodes`
    time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
    time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
    args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
    self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
    if args.num_sample_generations > 0:
        self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
    self.local_dataloader_batch_size = args.local_batch_size

    #########
    # setup model, optimizer, and others
    #########
    for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
        if module is not None:
            disable_dropout_in_model(module)
    self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
    self.model.config = self.policy_model.config  # needed for pushing to hub
    self.create_optimizer_and_scheduler(
        num_training_steps=args.num_total_batches
    )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

    #########
    ### trainer specifics
    #########
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
    self.control = TrainerControl()
    self.state = OnlineTrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )
    self.current_flos = 0
    self.hp_search_backend = None
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
    # Create distant repo and output directory if needed
    self.hub_model_id = None
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    #########
    ### setup dataloader
    #########
    self.dataloader = DataLoader(
        self.train_dataset,
        batch_size=self.local_dataloader_batch_size,
        shuffle=True,
        collate_fn=self.data_collator,
        drop_last=True,  # needed; otherwise the last batch will be of ragged shape
    )
    # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
    # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
    torch.manual_seed(args.seed)
    self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
    torch.manual_seed(self.local_seed)  # reset the local seed again

    self.eval_dataloader = DataLoader(
        self.eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=True,
    )  # no need to shuffle eval dataset
    self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

    if self.is_deepspeed_enabled:
        self.reward_model = prepare_deepspeed(
            self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
        )

        if self.ref_model is None:
            if not self.is_peft_model:
                raise ValueError("No reference model and model is not a Peft model.")
        else:
            self.ref_model = prepare_deepspeed(
                self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
    else:
        if self.ref_model is None:
            if not self.is_peft_model:
                raise ValueError("No reference model and model is not a Peft model.")
        else:
            self.ref_model = self.ref_model.to(self.accelerator.device)
        self.reward_model = self.reward_model.to(self.accelerator.device)
|
| 639 |
+
|
| 640 |
+
def get_train_dataloader(self) -> DataLoader:
    """Return the training dataloader held in `self.dataloader`."""
    train_loader = self.dataloader
    return train_loader
|
| 642 |
+
|
| 643 |
+
def get_eval_dataloader(self) -> DataLoader:
    """Return the evaluation dataloader held in `self.eval_dataloader`."""
    eval_loader = self.eval_dataloader
    return eval_loader
|
| 645 |
+
|
| 646 |
+
@contextmanager
def null_ref_context(self):
    """
    Context manager for handling null reference model (that is, peft adapter manipulation).

    When the policy is a PEFT model and no `ref_adapter_name` is configured, the body runs
    with the policy's adapters disabled. When `ref_adapter_name` is configured, that adapter
    is activated for the duration of the body and the policy adapter
    (`model_adapter_name`, or `"default"` when unset) is re-activated afterwards.
    """
    with (
        self.accelerator.unwrap_model(self.model.policy).disable_adapter()
        if self.is_peft_model and not self.ref_adapter_name
        else nullcontext()
    ):
        if self.ref_adapter_name:
            self.model.policy.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Restore the policy adapter even if the body raised, so a failed
            # reference forward pass cannot leave the reference adapter active.
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.model_adapter_name or "default")
|
| 659 |
+
|
| 660 |
+
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """
    Save only the policy part of the wrapped policy/value model.

    `self.model` (and `self.deepspeed` when DeepSpeed is enabled) is temporarily pointed at
    `self.model.policy` while the parent `save_model` runs, then restored.

    Args:
        output_dir: Forwarded unchanged to the parent `save_model`.
        _internal_call: Forwarded unchanged to the parent `save_model`.
    """
    backup_model = self.model
    # Capture the flag once so the restore step mirrors exactly what was swapped.
    swap_deepspeed = self.is_deepspeed_enabled
    if swap_deepspeed:
        backup_deepspeed = self.deepspeed

    self.model = self.model.policy  # save only the policy
    try:
        if swap_deepspeed:
            self.deepspeed = self.model

        super().save_model(output_dir, _internal_call)
    finally:
        # Always put the wrapper back; a failed save must not leave the trainer
        # holding the bare policy in place of the policy/value wrapper.
        self.model = backup_model
        if swap_deepspeed:
            self.deepspeed = backup_deepspeed
|
| 674 |
+
|
| 675 |
+
def train(self):
    """
    Run the PPO training loop.

    For each of `args.num_total_batches` updates this method:
      1. draws a batch of queries and generates responses with the policy,
      2. computes policy and reference log-probabilities, value estimates and reward scores,
      3. builds per-token rewards (KL term plus the score at the end of each response),
         optionally whitens them, and computes advantages and returns,
      4. runs `args.num_ppo_epochs` passes of minibatch/microbatch updates with the clipped
         policy and value losses,
      5. logs metrics, steps the LR scheduler, fires callbacks and optionally saves a
         checkpoint and samples completions.
    """
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    model = self.model
    ref_policy = self.ref_model
    reward_model = self.reward_model
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    def repeat_generator():
        # Cycle over the dataloader indefinitely; the outer loop decides when to stop.
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    start_time = time.time()
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    vf_loss_stats = torch.zeros(stats_shape, device=device)
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = args.num_total_batches * args.num_mini_batches
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model
        self.model_wrapped = self.model

    for update in range(1, args.num_total_batches + 1):
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        with torch.no_grad():
            queries = data["input_ids"].to(device)
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []
            values = []
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                query_responses, logitss = batch_generation(
                    unwrapped_model.policy,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query = queries[i : i + args.local_rollout_forward_batch_size]
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                logprob = selective_log_softmax(logits, response)
                del logits
                torch.cuda.empty_cache()

                if ref_policy is None:
                    with self.null_ref_context():
                        ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                else:
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                torch.cuda.empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. run reward model on the truncated responses
                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                unwrapped_value_model = accelerator.unwrap_model(model).value_model
                full_value, _, _ = get_reward(
                    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                )
                value = full_value[:, context_length - 1 : -1].squeeze(-1)
                _, score, _ = get_reward(
                    reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                )

                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)
                values.append(value)
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            values = torch.cat(values, 0)
            del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
            torch.cuda.empty_cache()
            gc.collect()

            # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
            sequence_lengths_p1 = sequence_lengths + 1
            padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
            values = torch.masked_fill(values, padding_mask_p1, 0)

            # 4. compute rewards
            kl = logprobs - ref_logprobs
            non_score_reward = -args.kl_coef * kl
            rewards = non_score_reward.clone()
            actual_start = torch.arange(rewards.size(0), device=rewards.device)
            actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
            # Index with a (row, column) pair rather than a list of tensors: PyTorch
            # deprecates non-tuple sequences for multidimensional indexing.
            rewards[actual_start, actual_end] += scores

            # 5. whiten rewards
            if args.whiten_rewards:
                rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

            # 6. compute advantages and returns
            lastgaelam = 0
            advantages_reversed = []
            gen_length = responses.shape[1]
            for t in reversed(range(gen_length)):
                nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                lastgaelam = delta + args.gamma * args.lam * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1], axis=1)
            returns = advantages + values
            advantages = masked_whiten(advantages, ~padding_mask)
            advantages = torch.masked_fill(advantages, padding_mask, 0)
            torch.cuda.empty_cache()

        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]
                        mb_return = returns[micro_batch_inds]
                        mb_values = values[micro_batch_inds]

                        output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )
                        vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                        vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                        vpredclipped = torch.clamp(
                            vpred,
                            mb_values - args.cliprange_value,
                            mb_values + args.cliprange_value,
                        )
                        vf_losses1 = torch.square(vpred - mb_return)
                        vf_losses2 = torch.square(vpredclipped - mb_return)
                        vf_loss_max = torch.max(vf_losses1, vf_losses2)
                        vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                        vf_clipfrac = masked_mean(
                            (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                        )
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                        loss = pg_loss + args.vf_coef * vf_loss
                        accelerator.backward(loss)
                        optimizer.step()
                        optimizer.zero_grad()
                        with torch.no_grad():
                            pg_clipfrac = masked_mean(
                                (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                            )
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            approxkl = 0.5 * (logprobs_diff**2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                            vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                vf_clipfrac
                            )
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1
                # del everything and empty cache
                # fmt: off
                del (
                    output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                    vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                    pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                    mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                torch.cuda.empty_cache()
        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.sum(1).mean()
            rlhf_reward = mean_non_score_reward + scores.mean()
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
            self.state.global_step += 1
            self.log(metrics)

        self.lr_scheduler.step()
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
        torch.cuda.empty_cache()
        gc.collect()

        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)
            torch.cuda.empty_cache()
        del (
            query_responses,
            responses,
            postprocessed_responses,
            logprobs,
            ref_logprobs,
            values,
            sequence_lengths,
            contain_eos_token,
            sequence_lengths_p1,
            response_idxs,
            padding_mask,
            padding_mask_p1,
            rewards,
            actual_start,
            actual_end,
            advantages,
            returns,
        )
        torch.cuda.empty_cache()

    # HF trainer specifics
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        self._save_checkpoint(model, trial=None, metrics=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 1012 |
+
|
| 1013 |
+
def generate_completions(self, sampling: bool = False):
    """
    Generate completions for the evaluation dataloader, score them and report a table.

    Each batch is generated with the unwrapped policy, truncated at `self.stop_token_id`
    when one is set, decoded, and scored with the reward model. The gathered queries,
    responses and scores are collected into a DataFrame; the main process prints the first
    five rows and forwards the table to wandb / comet_ml when those appear in
    `args.report_to`.

    Args:
        sampling: When true, stop after the first evaluation batch.
    """
    args = self.args
    processing_class = self.processing_class
    sampling_config = GenerationConfig(
        max_new_tokens=self.args.response_length,
        temperature=(0.01 + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    completions = defaultdict(list)
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        for batch in self.eval_dataloader:
            prompts = batch["input_ids"]
            with torch.no_grad():
                prompt_length = prompts.shape[1]
                generated, _ = batch_generation(
                    unwrapped_model.policy,
                    prompts,
                    prompts.shape[0],
                    processing_class.pad_token_id,
                    sampling_config,
                )
                completion = generated[:, prompt_length:]
                # `is None` rather than truthiness: a stop token id of 0 is valid.
                if self.stop_token_id is None:
                    trimmed = completion
                else:
                    trimmed = truncate_response(self.stop_token_id, processing_class.pad_token_id, completion)

                decoded_prompts = processing_class.batch_decode(prompts, skip_special_tokens=True)
                completions["query"].extend(gather_object(decoded_prompts))
                decoded_completions = processing_class.batch_decode(trimmed)
                completions["model response"].extend(gather_object(decoded_completions))

                scored_input = torch.cat((prompts, trimmed), 1)
                _, score, _ = get_reward(
                    self.reward_model, scored_input, processing_class.pad_token_id, prompt_length
                )
                gathered_scores = self.accelerator.gather_for_metrics(score)
                completions["score"].extend(gathered_scores.float().cpu().numpy())

            if sampling:
                break
    df = pd.DataFrame(completions)

    if self.accelerator.is_main_process:
        print_rich_table(df.iloc[0 : 0 + 5])
        if "wandb" in args.report_to:
            import wandb

            if wandb.run is not None:
                wandb.log({"completions": wandb.Table(dataframe=df)})

        if "comet_ml" in args.report_to:
            log_table_to_comet_experiment(
                name="completions.csv",
                table=df,
            )
|
| 1075 |
+
|
| 1076 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on
    processes for which `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The caller's list is never modified.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    config = self.model.config
    if hasattr(config, "_name_or_path") and not os.path.isdir(config._name_or_path):
        base_model = config._name_or_path
    else:
        base_model = None

    # Normalise `tags` to a fresh list so appending below cannot leak into the caller's object.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @article{mziegler2019fine-tuning,
        title = {{Fine-Tuning Language Models from Human Preferences}},
        author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
        year = 2019,
        eprint = {arXiv:1909.08593}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="PPO",
        trainer_citation=citation,
        paper_title="Fine-Tuning Language Models from Human Preferences",
        paper_id="1909.08593",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1131 |
+
class UnslothPPOTrainer(_UnslothPPOTrainer):
    """
    PPO trainer entry point that adjusts `args`, the tokenizer/processor and the data
    collator before delegating construction to `_UnslothPPOTrainer`.

    Before calling the parent initialiser it:
      * reconciles `args.fp16` / `args.bf16` (and the `*_full_eval` flags) with the model
        dtype and the `UNSLOTH_FORCE_FLOAT32` / `UNSLOTH_MIXED_PRECISION` environment
        variables, and exports `ACCELERATE_MIXED_PRECISION` accordingly;
      * fills in evaluation batch-size / accumulation defaults when evaluation is enabled;
      * forces right padding on `processing_class`;
      * swaps or builds the data collator depending on whether the dataset has `labels`.
    """
    def __init__(
        self,
        args,
        processing_class,
        model,
        ref_model,
        reward_model,
        train_dataset,
        value_model = None,
        data_collator = None,
        eval_dataset = None,
        callbacks = None,
        peft_config = None,
        **kwargs
    ):
        if args is None: args = UnslothPPOConfig()

        # ---- Mixed-precision reconciliation -------------------------------------------
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'
        if force_float32:
            print('Unsloth: Switching to float32 training since model cannot work with float16')
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if not force_float32:
            if float16 and use_bf16:
                raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
            if not float16 and use_fp16:
                raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: follow the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults ------------------------------------------------------
        # NOTE(review): this inspects `args.eval_dataset`, not the `eval_dataset` argument;
        # kept as-is to preserve behaviour - confirm which one was intended.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        # The requested full-eval flags are captured once, before any of the rewrites below.
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif mixed_precision_dtype == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # This signature has no `compute_metrics` / `preprocess_logits_for_metrics` /
        # `tokenizer` / `max_seq_length` parameters, so the former `locals()` probes for them
        # could never succeed; those unreachable branches are dropped.

        # ---- Sequence length ----------------------------------------------------------
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if getattr(args, 'max_seq_length', None) is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Padding side -------------------------------------------------------------
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'

        # ---- Data collator ------------------------------------------------------------
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if isinstance(data_collator, UnslothVisionDataCollator):
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        else:
            has_labels = 'labels' in train_dataset.column_names
            if isinstance(data_collator, DataCollatorForSeq2Seq) and not has_labels:
                data_collator = DataCollatorForLanguageModeling(processing_class, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and has_labels:
                data_collator = DataCollatorForSeq2Seq(processing_class)
            # Objects without `pad` that wrap a `.tokenizer` get a collator built on the
            # inner tokenizer instead.
            if not hasattr(processing_class, 'pad') and hasattr(processing_class, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(processing_class.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(processing_class.tokenizer, mlm = False)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ppo_trainer', other_metrics)

        super().__init__(
            args = args,
            processing_class = processing_class,
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            train_dataset = train_dataset,
            value_model = value_model,
            data_collator = data_collator,
            eval_dataset = eval_dataset,
            callbacks = callbacks,
            peft_config = peft_config,
            **kwargs
        )

        # ---- NEFTune ------------------------------------------------------------------
        # Drop any hook installed during parent initialisation and store the noise level
        # on the embedding module instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        requested_alpha = getattr(args, 'neftune_noise_alpha', None)
        if requested_alpha is not None:
            # Fall back to the configured value when the parent did not set the attribute,
            # rather than failing with AttributeError.
            model.get_input_embeddings().neftune_noise_alpha = getattr(self, 'neftune_noise_alpha', requested_alpha)
|
unsloth_compiled_cache/UnslothPRMTrainer.py
ADDED
|
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options handed to `torch.compile(options = ...)` by the compiled helpers in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` over its last dimension, evaluated only at `index`.

    Uses the identity log_softmax(x)_i = x_i - logsumexp(x), so the full log-softmax
    tensor is never built.

    Args:
        logits: score tensor; it is cast to float32 before any arithmetic.
        index: integer positions along the last dimension of `logits`. It is unsqueezed on
            the last axis before gathering, so it presumably has shape `logits.shape[:-1]`
            (confirm against callers).

    Returns:
        float32 tensor with one log-probability per entry of `index`.
    """
    logits_fp32 = logits.to(torch.float32)
    picked = logits_fp32.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    normaliser = logits_fp32.logsumexp(dim = -1)
    return picked - normaliser
|
| 42 |
+
@dataclass
class UnslothPRMConfig(PRMConfig):
    r"""
    Configuration class for the [`PRMTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        learning_rate (`float`, *optional*, defaults to `5e-5`):
            Initial learning rate for [`AdamW`] optimizer. Values below `1e-7` raise
            `FloatingPointError` and values above `1` raise `OverflowError`.
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) used for truncation.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt used for truncation.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        step_separator (`str`, *optional*, defaults to `"\n"`):
            Separator used to separate each step of the reasoning process.
        train_on_last_step_only (`bool`, *optional*, defaults to `False`):
            Whether to train only on the last step.
        dataset_num_proc (`int`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When `None`, the value of
            `multiprocessing.cpu_count()` is used.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to the parent config.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. Stored on the instance; not forwarded to the parent config.

    When `output_dir` is `None` and `save_strategy` / `save_steps` are left at `'steps'` / `500`,
    `output_dir` becomes `'unsloth_training_checkpoints'` and `save_strategy` becomes `'no'`.
    All remaining keyword arguments are passed unchanged to the parent configuration.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        tp_size = 0,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        max_length = 1024,
        max_prompt_length = 512,
        max_completion_length = None,
        disable_dropout = True,
        # Previously written as a quote, a backslash, a physical line break and a quote,
        # which Python reads as a line continuation, i.e. the empty string.
        step_separator = '\n',
        train_on_last_step_only = False,
        dataset_num_proc = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before anything else is configured.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # With no output directory and untouched save settings, pick a directory name and
        # turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            tp_size = tp_size,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            max_length = max_length,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_completion_length,
            disable_dropout = disable_dropout,
            step_separator = step_separator,
            train_on_last_step_only = train_on_last_step_only,
            dataset_num_proc = dataset_num_proc,
            **kwargs
        )
        # Unsloth-only settings are kept on the instance rather than sent to the parent.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 374 |
+
|
| 375 |
+
class _UnslothPRMTrainer(Trainer):
|
| 376 |
+
""""""
|
| 377 |
+
|
| 378 |
+
_tag_names = ["trl", "prm"]
|
| 379 |
+
|
| 380 |
+
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module]] = None,
    args: Optional[PRMConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
        None,
        None,
    ),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
):
    """
    Prepare the model, data collator and datasets, then initialize the base `Trainer`.

    Args:
        model: Model to train. If `peft_config` is given, the model is not already a `PeftModel`, and it
            is flagged `is_loaded_in_8bit` or `is_quantized`, it is passed through
            `prepare_model_for_kbit_training`.
        args: Training configuration. It is read unconditionally (`disable_dropout`, `max_length`, ...),
            so it must be provided despite the `None` default.
        data_collator: Collator to use. Defaults to `DataCollatorForTokenClassification` built from
            `processing_class` and `args.max_length`.
        train_dataset: Training dataset, or `None`. Tokenized with `tokenize_row` when it has no
            `"input_ids"` column.
        eval_dataset: Evaluation dataset, a dict of named evaluation datasets, or `None`. Each dataset
            without an `"input_ids"` column is tokenized with `tokenize_row` (`is_eval=True`).
        processing_class: Tokenizer/processor handed to `tokenize_row` and to the base `Trainer`.
        model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics:
            Forwarded to the base `Trainer`. `compute_metrics` defaults to `compute_accuracy`.
        peft_config: Only used to decide whether k-bit preparation is needed; this method does not wrap
            the model in a PEFT model itself.
            NOTE(review): presumably the caller applies PEFT before constructing the trainer - confirm.

    Raises:
        ValueError: If `peft_config` is given but PEFT is not installed, or if no `data_collator` and no
            `processing_class` are given.
    """
    if peft_config is not None:
        if not is_peft_available():
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        needs_kbit_preparation = not isinstance(model, PeftModel) and (
            getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False)
        )
        if needs_kbit_preparation:
            # Older peft releases do not accept `gradient_checkpointing_kwargs`.
            _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if args.gradient_checkpointing_kwargs is not None:
                if _supports_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
                else:
                    warnings.warn(
                        "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                        "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                    )

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(model)

    if compute_metrics is None:
        compute_metrics = compute_accuracy

    if data_collator is None:
        if processing_class is None:
            raise ValueError(
                "A processing_class must be specified when using the default DataCollatorForTokenClassification"
            )
        data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)

    def _is_raw(dataset):
        # A dataset needs tokenizing when it exists and has no "input_ids" column yet.
        return dataset is not None and "input_ids" not in dataset.column_names

    def _tokenize(dataset, is_eval, desc):
        # Map `tokenize_row` over the dataset, replacing all original columns.
        return dataset.map(
            self.tokenize_row,
            fn_kwargs={
                "tokenizer": processing_class,
                "step_separator": args.step_separator,
                "max_length": args.max_length,
                "max_prompt_length": args.max_prompt_length,
                "max_completion_length": args.max_completion_length,
                "train_on_last_step_only": args.train_on_last_step_only,
                "is_eval": is_eval,
            },
            num_proc=args.dataset_num_proc,
            remove_columns=dataset.features,
            desc=desc,
            features=features.Features(  # needed to avoid map to cast labels to bool
                {
                    "labels": features.Sequence(features.Value("int64")),
                    "input_ids": features.Sequence(features.Value("int64")),
                }
            ),
        )

    # `eval_dataset` may be a single dataset or a dict of named datasets; treat both uniformly.
    eval_splits = list(eval_dataset.values()) if isinstance(eval_dataset, dict) else [eval_dataset]
    if _is_raw(train_dataset) or any(_is_raw(split) for split in eval_splits):
        with PartialState().local_main_process_first():
            if _is_raw(train_dataset):
                train_dataset = _tokenize(train_dataset, False, "Tokenizing train dataset")

            if isinstance(eval_dataset, dict):
                eval_dataset = {
                    name: (
                        _tokenize(split, True, f"Tokenizing eval dataset ({name})") if _is_raw(split) else split
                    )
                    for name, split in eval_dataset.items()
                }
            elif _is_raw(eval_dataset):
                eval_dataset = _tokenize(eval_dataset, True, "Tokenizing eval dataset")

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)
|
| 497 |
+
|
| 498 |
+
@staticmethod
|
| 499 |
+
def tokenize_row(
|
| 500 |
+
features,
|
| 501 |
+
tokenizer,
|
| 502 |
+
step_separator,
|
| 503 |
+
max_length,
|
| 504 |
+
max_prompt_length,
|
| 505 |
+
max_completion_length,
|
| 506 |
+
train_on_last_step_only,
|
| 507 |
+
is_eval,
|
| 508 |
+
):
|
| 509 |
+
r"""
|
| 510 |
+
Tokenize a row of the dataset.
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
features (`dict[str, str]`):
|
| 514 |
+
Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
|
| 515 |
+
tokenizer (`PreTrainedTokenizerBase`):
|
| 516 |
+
Tokenizer used to process the data.
|
| 517 |
+
step_separator (`str`):
|
| 518 |
+
Separator between steps in the completion.
|
| 519 |
+
max_length (`int` or `None`):
|
| 520 |
+
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
|
| 521 |
+
max_prompt_length (`int` or `None`):
|
| 522 |
+
Maximum length of the prompt. If `None`, the prompt is not truncated.
|
| 523 |
+
max_completion_length (`int` or `None`):
|
| 524 |
+
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
|
| 525 |
+
train_on_last_step_only (`bool`):
|
| 526 |
+
Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
|
| 527 |
+
token of the completion.
|
| 528 |
+
is_eval (`bool`):
|
| 529 |
+
Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
|
| 530 |
+
|
| 531 |
+
Returns:
|
| 532 |
+
`dict[str, list[int]]`:
|
| 533 |
+
Tokenized sequences with the keys `"input_ids"`, and `"labels".
|
| 534 |
+
|
| 535 |
+
Example:
|
| 536 |
+
```python
|
| 537 |
+
>>> from transformers import AutoTokenizer
|
| 538 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
| 539 |
+
>>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
|
| 540 |
+
... "completions": ["11 is greater than 8.",
|
| 541 |
+
... "Hence, 9.11 > 9.8."],
|
| 542 |
+
... "labels": [True, False]}
|
| 543 |
+
>>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
|
| 544 |
+
{'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
|
| 545 |
+
'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
|
| 546 |
+
```
|
| 547 |
+
"""
|
| 548 |
+
# Tokenize the prompt and completions
|
| 549 |
+
prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
|
| 550 |
+
completions_ids = [
|
| 551 |
+
tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
|
| 552 |
+
]
|
| 553 |
+
if train_on_last_step_only and not is_eval:
|
| 554 |
+
labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
|
| 555 |
+
else:
|
| 556 |
+
labels = [int(label) for label in features["labels"]]
|
| 557 |
+
|
| 558 |
+
# Get the ID of the separator token and add it to the completions
|
| 559 |
+
separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
|
| 560 |
+
completions_ids = [completion + separator_ids for completion in completions_ids]
|
| 561 |
+
|
| 562 |
+
# Create the label
|
| 563 |
+
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
|
| 564 |
+
|
| 565 |
+
# Join the completions and labels steps
|
| 566 |
+
completion_ids = list(chain(*completions_ids))
|
| 567 |
+
labels = list(chain(*labels))
|
| 568 |
+
|
| 569 |
+
if tokenizer.bos_token_id is not None:
|
| 570 |
+
prompt_ids = [tokenizer.bos_token_id] + prompt_ids
|
| 571 |
+
|
| 572 |
+
# Truncate prompt and completion sequences
|
| 573 |
+
if max_prompt_length is not None:
|
| 574 |
+
prompt_ids = prompt_ids[-max_prompt_length:]
|
| 575 |
+
if max_completion_length is not None:
|
| 576 |
+
completion_ids = completion_ids[:max_completion_length]
|
| 577 |
+
labels = labels[:max_completion_length]
|
| 578 |
+
|
| 579 |
+
input_ids = prompt_ids + completion_ids
|
| 580 |
+
labels = [-100] * len(prompt_ids) + labels
|
| 581 |
+
|
| 582 |
+
if max_length is not None:
|
| 583 |
+
input_ids = input_ids[:max_length]
|
| 584 |
+
labels = labels[:max_length]
|
| 585 |
+
|
| 586 |
+
return {"input_ids": input_ids, "labels": labels}
|
| 587 |
+
|
| 588 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on processes for
    which `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The argument is never modified; `"unsloth"` is
            added to a private copy when the model config has an `unsloth_version` attribute.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when the recorded name is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalize `tags` to a fresh list so the caller's list is not mutated by the append below.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @article{uesato2022solving,
        title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
        author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
        year = 2022,
        journal = {arXiv preprint arXiv:2211.14275}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        trainer_name="PRM",
        trainer_citation=citation,
        paper_title="Solving math word problems with process-and outcome-based feedback",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 641 |
+
class UnslothPRMTrainer(_UnslothPRMTrainer):
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
Initialize PRMTrainer.
|
| 645 |
+
|
| 646 |
+
Args:
|
| 647 |
+
model (`transformers.PreTrainedModel`):
|
| 648 |
+
The model to train, preferably an `AutoModelForTokenClassification`.
|
| 649 |
+
args (`PRMConfig`):
|
| 650 |
+
The arguments to use for training.
|
| 651 |
+
data_collator (`transformers.DataCollator`):
|
| 652 |
+
The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
|
| 653 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 654 |
+
train_dataset (`datasets.Dataset`):
|
| 655 |
+
The dataset to use for training.
|
| 656 |
+
eval_dataset (`datasets.Dataset`):
|
| 657 |
+
The dataset to use for evaluation.
|
| 658 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 659 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 660 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 661 |
+
reuse the fine-tuned model.
|
| 662 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 663 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 664 |
+
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
| 665 |
+
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
| 666 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 667 |
+
The callbacks to use for training.
|
| 668 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 669 |
+
The optimizer and scheduler to use for training.
|
| 670 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 671 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 672 |
+
peft_config (`dict`, defaults to `None`):
|
| 673 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 674 |
+
|
| 675 |
+
"""
|
| 676 |
+
def __init__(
    self,
    model = None,
    args = None,
    data_collator = None,
    train_dataset = None,
    eval_dataset = None,
    processing_class = None,
    model_init = None,
    compute_metrics = None,
    callbacks = None,
    preprocess_logits_for_metrics = None,
    peft_config = None,
    **kwargs
):
    """
    Reconcile precision, evaluation and collator settings, then delegate to the parent `__init__`.

    Before calling `super().__init__`, this method mutates `args` in place (fp16/bf16 flags, full-eval
    flags, eval batch size / accumulation, `max_seq_length`, column handling), may replace
    `data_collator`, forces right padding on `processing_class`, and sets the environment variables
    `ACCELERATE_MIXED_PRECISION` and `UNSLOTH_RETURN_LOGITS`. It reads `UNSLOTH_FORCE_FLOAT32` and
    `UNSLOTH_MIXED_PRECISION` from the environment.

    `model` is dereferenced unconditionally (`model.config`), so it must not be `None` despite the
    default. Extra keyword arguments are forwarded to the parent unchanged.

    Raises:
        TypeError: If the requested fp16/bf16 flag contradicts the model's dtype and float32 is not forced.
    """
    if args is None: args = UnslothPRMConfig()
    use_bf16 = getattr(args, 'bf16', False)
    use_fp16 = getattr(args, 'fp16', False)
    force_float32 = False
    if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
        print('Unsloth: Switching to float32 training since model cannot work with float16')
        force_float32 = True
    mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
    # Model dtype: taken from the config, falling back to the input embedding dtype.
    dtype = getattr(model.config, 'torch_dtype', None)
    if dtype is None: dtype = model.get_input_embeddings().dtype
    from unsloth_zoo.utils import _get_dtype
    dtype = _get_dtype(dtype)
    float16 = dtype == torch.float16
    # Reject precision flags that contradict the model dtype (any non-float16 dtype is treated as bfloat16).
    if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
    if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
    if force_float32:
        args.fp16 = False
        args.bf16 = False
        os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
    elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
        # Neither flag was set: choose the mixed-precision mode that matches the model dtype.
        args.fp16 = float16
        args.bf16 = not float16
        os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
    # NOTE(review): this looks up `eval_dataset` on `args`, not this method's `eval_dataset`
    # parameter. If `args` never carries such an attribute the branch cannot run - confirm intent.
    if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
        args.eval_strategy = 'steps'
        if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
    ga_steps = getattr(args, 'gradient_accumulation_steps', None)
    if ga_steps is not None and ga_steps > 1:
        from transformers import __version__ as transformers_version
        if Version(transformers_version) <= Version('4.45.2'):
            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                  '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
    if getattr(args, 'eval_strategy', 'no') != 'no':
        # When evaluation is on, cap the eval batch size (only if it is 8) at the train batch size
        # and default eval accumulation to the gradient accumulation steps.
        eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
        if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
        if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
    fp16_full_eval = getattr(args, 'fp16_full_eval', False)
    bf16_full_eval = getattr(args, 'bf16_full_eval', False)
    # Keep the full-eval flags on the same precision as the training flags.
    if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
    if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
    if force_float32:
        args.bf16_full_eval = False
        args.fp16_full_eval = False
    elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
        args.bf16_full_eval = True
        args.fp16_full_eval = False
    elif not bf16_full_eval and not fp16_full_eval:
        args.bf16_full_eval = args.bf16
        args.fp16_full_eval = args.fp16
    # Request logits through the environment whenever a metric or logit-preprocessing hook is given.
    _output_logits = False
    if locals().get('compute_metrics', None) is not None: _output_logits = True
    if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
    if _output_logits:
        os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
    # `max_seq_length` is not bound as a local at this point, so this test reduces to
    # "args has no `max_seq_length` attribute".
    if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
        pass
    else:
        # Inherit the model's `max_seq_length` when `args` leaves it unset.
        model_max_seq_length = getattr(model, 'max_seq_length', None)
        args_max_seq_length = getattr(args, 'max_seq_length', None)
        if args_max_seq_length is None and model_max_seq_length is not None:
            max_seq_length = model.max_seq_length
            if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
    if model is not None and hasattr(model, 'for_training'):
        model.for_training()
    # There is no `tokenizer` parameter or local in this method, so the next check is always false;
    # `processing_class` is a parameter, so its `in locals()` check is always true.
    if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
    if 'processing_class' in locals():
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
    __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
    from unsloth_zoo.vision_utils import UnslothVisionDataCollator
    if not isinstance(data_collator, UnslothVisionDataCollator):
        # Swap between the seq2seq and language-modeling collators depending on whether the
        # training dataset has a 'labels' column.
        if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
            data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
        elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
            data_collator = DataCollatorForSeq2Seq(__tokenizer)
    else:
        # Vision collator: keep all columns and skip dataset preparation.
        if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
        if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
        if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
    if not isinstance(data_collator, UnslothVisionDataCollator):
        # The processing class has no `pad` but wraps a `.tokenizer`: rebuild the collator around
        # that inner tokenizer.
        # NOTE(review): this also replaces a `None` collator with a language-modeling collator,
        # which then bypasses the parent's default collator - confirm this is intended for PRM.
        if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
            if isinstance(data_collator, DataCollatorForSeq2Seq):
                data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
            else:
                data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
    other_metrics = []

    from unsloth_zoo.logging_utils import PatchRLStatistics
    PatchRLStatistics('prm_trainer', other_metrics)

    super().__init__(
        model = model,
        args = args,
        data_collator = data_collator,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        processing_class = processing_class,
        model_init = model_init,
        compute_metrics = compute_metrics,
        callbacks = callbacks,
        preprocess_logits_for_metrics = preprocess_logits_for_metrics,
        peft_config = peft_config,**kwargs)
    # Remove any NEFTune hook left on the trainer by the parent constructor, then record the
    # noise alpha directly on the input embedding module.
    if hasattr(self, 'neftune_hook_handle'):
        self.neftune_hook_handle.remove()
        if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
    if getattr(args, 'neftune_noise_alpha', None) is not None:
        model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
    pass
|
| 799 |
+
|
| 800 |
+
pass
|
unsloth_compiled_cache/UnslothRLOOTrainer.py
ADDED
|
@@ -0,0 +1,1133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options handed to `torch.compile(..., options=...)` by the compiled helper in this module.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Log-softmax of `logits` over the last dimension, evaluated only at `index`.

    `logits` is upcast to float32 first. `index` holds one position along the last dimension of
    `logits` for every leading position; the result has the shape of `index`.
    """
    logits_fp32 = logits.to(torch.float32)
    # log_softmax(x)_i = x_i - logsumexp(x), so only the selected entries need to be gathered.
    log_normalizer = torch.logsumexp(logits_fp32, dim = -1)
    picked_logits = logits_fp32.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    return picked_logits - log_normalizer
|
| 42 |
+
@dataclass
class UnslothRLOOConfig(RLOOConfig):
    """
    Configuration class for the [`RLOOTrainer`], extended with Unsloth-specific options.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Every keyword accepted by `__init__` other than `vllm_sampling_params` and `unsloth_num_chunks` is forwarded
    unchanged to `RLOOConfig.__init__`; only the RLOO-specific and Unsloth-specific ones are described here.

    Parameters:
        exp_name (`str`, *optional*, defaults to `"rloo_config"`):
            Name of this experiment.
        reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
            Path to the reward model.
        num_ppo_epochs (`int`, *optional*, defaults to `4`):
            Number of epochs to train.
        whiten_rewards (`bool`, *optional*, defaults to `False`):
            Whether to whiten the rewards.
        kl_coef (`float`, *optional*, defaults to `0.05`):
            KL coefficient.
        cliprange (`float`, *optional*, defaults to `0.2`):
            Clip range.
        rloo_k (`int`, *optional*, defaults to `2`):
            REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
        normalize_reward (`bool`, *optional*, defaults to `False`):
            Whether to normalize rewards.
        reward_clip_range (`float`, *optional*, defaults to `10.0`):
            Clip range for rewards.
        normalize_advantage (`bool`, *optional*, defaults to `False`):
            Whether to normalize advantages.
        token_level_kl (`bool`, *optional*, defaults to `False`):
            Whether to use token-level KL penalty or sequence-level KL penalty.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to `RLOOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance; not forwarded to
            `RLOOConfig`.

    Raises:
        FloatingPointError: If `learning_rate` is below `1e-7`.
        OverflowError: If `learning_rate` is above `1`.

    Notes:
        - When `output_dir` is `None` and both `save_strategy` and `save_steps` are at their defaults
          (`'steps'` / `500`), `output_dir` becomes `'unsloth_training_checkpoints'` and `save_strategy`
          is switched to `'no'`.
        - When `dataset_num_proc` is `None` it is replaced by `multiprocessing.cpu_count()`.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        tp_size = 0,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        dataset_num_proc = None,
        num_mini_batches = 1,
        total_episodes = None,
        local_rollout_forward_batch_size = 64,
        num_sample_generations = 10,
        response_length = 53,
        stop_token = None,
        stop_token_id = None,
        temperature = 0.7,
        missing_eos_penalty = None,
        sft_model_path = 'EleutherAI/pythia-160m',
        world_size = None,
        num_total_batches = None,
        micro_batch_size = None,
        local_batch_size = None,
        batch_size = None,
        local_mini_batch_size = None,
        mini_batch_size = None,
        exp_name = 'rloo_config',
        reward_model_path = 'EleutherAI/pythia-160m',
        num_ppo_epochs = 4,
        whiten_rewards = False,
        kl_coef = 0.05,
        cliprange = 0.2,
        rloo_k = 2,
        normalize_reward = False,
        reward_clip_range = 10.0,
        normalize_advantage = False,
        token_level_kl = False,
        ds3_gather_for_generation = True,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Reject learning rates that would make updates vanish or explode.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # Nothing save-related was customised: pick a default directory and
        # turn periodic checkpointing off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Default dataset preprocessing parallelism to the machine's CPU count.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            tp_size = tp_size,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            dataset_num_proc = dataset_num_proc,
            num_mini_batches = num_mini_batches,
            total_episodes = total_episodes,
            local_rollout_forward_batch_size = local_rollout_forward_batch_size,
            num_sample_generations = num_sample_generations,
            response_length = response_length,
            stop_token = stop_token,
            stop_token_id = stop_token_id,
            temperature = temperature,
            missing_eos_penalty = missing_eos_penalty,
            sft_model_path = sft_model_path,
            world_size = world_size,
            num_total_batches = num_total_batches,
            micro_batch_size = micro_batch_size,
            local_batch_size = local_batch_size,
            batch_size = batch_size,
            local_mini_batch_size = local_mini_batch_size,
            mini_batch_size = mini_batch_size,
            exp_name = exp_name,
            reward_model_path = reward_model_path,
            num_ppo_epochs = num_ppo_epochs,
            whiten_rewards = whiten_rewards,
            kl_coef = kl_coef,
            cliprange = cliprange,
            rloo_k = rloo_k,
            normalize_reward = normalize_reward,
            reward_clip_range = reward_clip_range,
            normalize_advantage = normalize_advantage,
            token_level_kl = token_level_kl,
            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
        # Unsloth-only options are kept on the instance and never reach the parent.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 427 |
+
pass
|
| 428 |
+
|
| 429 |
+
class _UnslothRLOOTrainer(Trainer):
|
| 430 |
+
_tag_names = ["trl", "rloo"]
|
| 431 |
+
|
| 432 |
+
def __init__(
    self,
    config: RLOOConfig,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ],
    policy: nn.Module,
    ref_policy: nn.Module,
    reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
    train_dataset: Dataset,
    data_collator: Optional[DataCollatorWithPadding] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    # less commonly used
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    callbacks: Optional[list[TrainerCallback]] = None,
) -> None:
    """
    Set up models, batch-size bookkeeping, optimizer, callbacks and dataloaders for RLOO training.

    Args:
        config: Training configuration. It is stored as `self.args` and mutated in place:
            `total_episodes` (when `None`), `world_size`, the derived batch sizes,
            `num_total_batches`, `run_name` and possibly `stop_token_id` are written to it.
        processing_class: Tokenizer/processor. Used to build the default collator, to resolve
            the EOS id when `stop_token == "eos"`, and handed to the callback handler.
        policy: Model being trained. Its `generation_config.eos_token_id` and
            `generation_config.pad_token_id` are set to `None`.
        ref_policy: Reference model; must be a different object from `policy`.
        reward_model: Either an `nn.Module` or a callable. Only the `nn.Module` form has
            dropout disabled and is moved to the device / wrapped for DeepSpeed.
        train_dataset: Training dataset; must support `len()`.
        data_collator: Collator for both dataloaders. Defaults to
            `DataCollatorWithPadding(processing_class)`.
        eval_dataset: Dataset used to build `self.eval_dataloader`.
        optimizers: `(optimizer, lr_scheduler)` pair assigned to `self.optimizer` and
            `self.lr_scheduler` before `create_optimizer_and_scheduler` is called.
        callbacks: Extra callbacks appended to the default ones.

    Raises:
        ValueError: If `ref_policy` is the same object as `policy`.
    """
    if ref_policy is policy:
        raise ValueError(
            "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
            "same as `policy`, you must pass a copy of it, or `None` if you use peft."
        )
    # NOTE(review): the message above mentions `None` for peft, but `ref_policy.to(...)` /
    # `prepare_deepspeed(ref_policy, ...)` below are called unconditionally — confirm whether
    # a `None` reference policy is actually supported here.

    self.args = config
    args = config
    self.processing_class = processing_class
    self.policy = policy

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DataCollatorWithPadding(self.processing_class)

    # Disable `eos_token_id` and `pad_token_id` because we just want to generate tokens
    # without truncation / padding.
    self.policy.generation_config.eos_token_id = None
    self.policy.generation_config.pad_token_id = None

    self.ref_policy = ref_policy
    self.reward_model = reward_model
    self.train_dataset = train_dataset
    self.train_dataset_len = len(train_dataset)
    self.data_collator = data_collator
    self.eval_dataset = eval_dataset
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

    #########
    # calculate various batch sizes
    #########
    if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
        args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    self.accelerator = accelerator
    args.world_size = accelerator.num_processes
    # Samples consumed by one process per batch.
    args.local_batch_size = (
        args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
    )
    args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
    # Samples consumed by all processes per batch.
    args.batch_size = int(args.local_batch_size * args.world_size)
    args.mini_batch_size = exact_div(
        args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
    )
    args.local_mini_batch_size = exact_div(
        args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
    )
    args.num_total_batches = math.ceil(
        args.total_episodes / args.batch_size
    )  # we may train for more than `total_episodes`
    time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
    time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
    args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
    # Per-process seed: offset the base seed by the process index.
    self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
    if args.num_sample_generations > 0:
        self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
    self.local_dataloader_batch_size = exact_div(
        args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
    )  # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times

    #########
    # setup model, optimizer, and others
    #########
    # A callable reward function is not an nn.Module and is skipped here.
    for module in [policy, ref_policy, reward_model]:
        if isinstance(module, nn.Module):
            disable_dropout_in_model(module)
    if args.stop_token and args.stop_token == "eos":
        args.stop_token_id = self.processing_class.eos_token_id
    self.model = policy
    self.create_optimizer_and_scheduler(
        num_training_steps=args.num_total_batches
    )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

    #########
    ### trainer specifics
    #########
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
    self.control = TrainerControl()
    self.state = OnlineTrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )

    self.current_flos = 0
    self.hp_search_backend = None
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
    # Create distant repo and output directory if needed
    self.hub_model_id = None
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)
    self.backup_model = None

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    #########
    ### setup dataloader
    #########
    self.dataloader = DataLoader(
        self.train_dataset,
        batch_size=self.local_dataloader_batch_size,
        shuffle=True,
        collate_fn=self.data_collator,
        drop_last=True,  # needed; otherwise the last batch will be of ragged shape
    )
    # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
    # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
    torch.manual_seed(args.seed)
    self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
    torch.manual_seed(self.local_seed)  # reset the local seed again

    # NOTE(review): `eval_dataset` defaults to None but is handed to DataLoader
    # unconditionally — confirm callers always supply one.
    self.eval_dataloader = DataLoader(
        self.eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=True,
    )  # no need to shuffle eval dataset
    self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

    # The reference policy and reward model are not passed through `accelerator.prepare`;
    # they are either wrapped for DeepSpeed or moved to the accelerator device by hand.
    if self.is_deepspeed_enabled:
        if isinstance(self.reward_model, nn.Module):
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
        self.ref_policy = prepare_deepspeed(
            self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
        )
        self.deepspeed = self.model
    else:
        self.ref_policy = self.ref_policy.to(self.accelerator.device)
        if isinstance(self.reward_model, nn.Module):
            self.reward_model = self.reward_model.to(self.accelerator.device)
|
| 593 |
+
|
| 594 |
+
def get_train_dataloader(self) -> DataLoader:
    """Return the training dataloader stored on this trainer as `self.dataloader`."""
    train_loader = self.dataloader
    return train_loader
|
| 596 |
+
|
| 597 |
+
def get_eval_dataloader(self) -> DataLoader:
    """Return the evaluation dataloader stored on this trainer as `self.eval_dataloader`."""
    eval_loader = self.eval_dataloader
    return eval_loader
|
| 599 |
+
|
| 600 |
+
def train(self):
    """
    Run the RLOO policy-optimisation loop.

    The method performs `args.num_total_batches` updates. Each update:

    1. draws one batch of prompts from `self.dataloader` (cycled indefinitely) and repeats it
       `args.rloo_k` times;
    2. generates responses with the current policy and computes per-token log-probabilities under
       both the policy and `self.ref_policy`;
    3. scores the (optionally truncated) responses with `self.reward_model`, which is either an
       `nn.Module` (scored through `get_reward`) or a callable applied to the decoded text;
    4. combines the score with a KL penalty (token-level or sequence-level, depending on
       `args.token_level_kl`) and derives leave-one-out advantages across the `rloo_k` samples;
    5. runs `args.num_ppo_epochs` epochs of clipped policy-gradient updates over shuffled
       mini-/micro-batches;
    6. logs metrics, steps the LR scheduler, fires the trainer callbacks and saves a checkpoint when
       the callback control requests it.

    The method returns nothing; results are reported through `self.log` and the callback handler.
    """
    # Bind frequently used attributes to locals for the duration of the loop.
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    model = self.model
    self.model_wrapped = self.model
    ref_policy = self.ref_policy
    reward_model = self.reward_model
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    def repeat_generator():
        """Yield batches from `dataloader` forever, restarting it each time it is exhausted."""
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    # Sampling configuration used for rollouts; the 1e-7 offset keeps the temperature strictly
    # positive even when `args.temperature` is 0.
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    start_time = time.time()
    # One slot per (ppo epoch, mini-batch, gradient-accumulation step); overwritten on every update.
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    # NOTE: `vf_clipfrac_stats` is never written in this method, so the metric derived from it
    # ("val/clipfrac_avg") is always zero.
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    # (a value below 1 is treated as a fraction of `max_steps`).
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    for update in range(1, args.num_total_batches + 1):
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        # ---- Rollout phase: everything below is computed without gradients. ----
        with torch.no_grad():
            queries = data["input_ids"].to(device)
            # Tile the whole prompt batch `rloo_k` times along dim 0, so copy j of prompt i sits at
            # row j * batch + i; the reshape to (rloo_k, -1) further down relies on this layout.
            queries = queries.repeat(args.rloo_k, 1)
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []

            # Generate responses and compute logprobs
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                query_responses, logitss = batch_generation(
                    unwrapped_model,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            # Process responses in batches
            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query = queries[i : i + args.local_rollout_forward_batch_size]
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                # Everything after the prompt columns is the generated response.
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                logprob = selective_log_softmax(logits, response)
                del logits
                torch.cuda.empty_cache()

                # Reference-policy log-probs for the same tokens; logits are shifted by one
                # position and scaled (in place) by the same temperature used for sampling.
                ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                torch.cuda.empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        args.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. run reward model on the truncated responses
                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                # NOTE(review): presumably the index of the last non-pad response token — confirm
                # against `first_true_indices`.
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1

                if isinstance(reward_model, nn.Module):
                    _, score, _ = get_reward(
                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )
                else:
                    # Non-module reward: call it on the decoded prompt+response strings and wrap
                    # the result in a float tensor on the training device.
                    score = torch.tensor(
                        reward_model(
                            processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
                        ),
                        dtype=torch.float,
                    ).to(device)

                # Store batch results
                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)

            # Concatenate all batched results
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            del (logprob, ref_logprob, score)
            torch.cuda.empty_cache()
            gc.collect()

            # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
            # responses not passing that filter will receive a low (fixed) score
            # only query humans on responses that pass that filter
            contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
            if args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            # Positions strictly after `sequence_lengths` are treated as padding and their
            # log-probs are overwritten with INVALID_LOGPROB.
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)

            # 4. compute rewards
            # Compute KL divergence
            # (per-token difference of log-probs; masked positions hold the same constant in both
            # tensors and therefore contribute 0)
            kl = logprobs - ref_logprobs

            # Normalize rewards
            if args.normalize_reward:
                scores = (scores - scores.mean()) / (scores.std() + 1e-8)
                scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)

            # Compute total reward with KL penalty
            if args.token_level_kl:
                # Token-level KL penalty: apply KL penalty per token
                kl_reward = -args.kl_coef * kl

                # Get the index of the last non-padded token for each sequence
                eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
                last_reward = torch.zeros_like(kl)
                # Ensure scores has correct shape and type
                scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
                # Place each sequence's score at the single position selected above.
                last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)

                # Combine KL reward and last reward
                non_score_reward = kl_reward.sum(1)  # Keep this for logging
                reward = last_reward + kl_reward
                rlhf_reward = reward.sum(1)  # Sum across sequence length
            else:
                # Sequence-level KL penalty: sum KL across tokens first
                sequence_kl = kl.sum(1)
                non_score_reward = -args.kl_coef * sequence_kl
                rlhf_reward = non_score_reward + scores

            # vectorized RLOO advantages implementation
            # Row j holds the j-th sample of every prompt; the baseline for a sample is the mean
            # reward of the other `rloo_k - 1` samples of the same prompt.
            rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
            baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
            advantages = rlhf_reward - baseline
            advantages = advantages.flatten()

            # Normalize advantages
            if args.normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            torch.cuda.empty_cache()

        # ---- Optimisation phase ----
        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]

                        # Get batch data
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]

                        # Forward pass
                        output = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7

                        # Compute new logprobs
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )

                        # Compute probability ratios
                        # `new_ratio` is the per-token ratio and is only used for logging; the loss
                        # uses the sequence-level `ratio` built from summed log-probs.
                        new_ratio = (new_logprobs - mb_logprobs).exp()
                        new_logprobs = new_logprobs.sum(1)
                        mb_logprobs = mb_logprobs.sum(1)
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)

                        # PPO clipped loss
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = pg_loss_max.mean()

                        # Final loss
                        loss = pg_loss

                        # Optimization step
                        accelerator.backward(loss)
                        optimizer.step()
                        optimizer.zero_grad()

                        # Record diagnostics for this micro-batch (no gradients needed).
                        with torch.no_grad():
                            pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            approxkl = 0.5 * (logprobs_diff**2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1

                # del everything and empty cache
                # fmt: off
                del (
                    output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
                    pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
                    mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                torch.cuda.empty_cache()

        # Compute metrics
        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.mean()
            # Episodes processed per second of wall-clock time since training started.
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            # Counted on the local `responses` tensor only (no gather).
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len)  # used by self.log
            self.log(metrics)
        del kl, mean_kl, mean_entropy, scores

        # The LR scheduler and global step advance once per update, not per optimizer step.
        self.lr_scheduler.step()
        self.state.global_step += 1
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        torch.cuda.empty_cache()
        gc.collect()

        # Periodically print/log a sample of completions from the eval dataloader.
        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)

    # HF trainer specifics
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        self._save_checkpoint(model, trial=None, metrics=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 923 |
+
|
| 924 |
+
def generate_completions(self, sampling: bool = False):
    """
    Generate completions for the evaluation prompts, score them and report a table.

    For every batch of `self.eval_dataloader` the policy generates a response, the response is
    optionally truncated at `args.stop_token_id`, and the prompt+response is scored with
    `self.reward_model`. Decoded prompts, decoded responses and scores are gathered across
    processes into a table. On the main process the first five rows are printed and the full
    table is sent to wandb / comet_ml when those appear in `args.report_to`.

    Args:
        sampling (`bool`, *optional*, defaults to `False`):
            When `True`, stop after the first evaluation batch.
    """
    args = self.args
    processing_class = self.processing_class
    pad_token_id_of = processing_class  # alias kept so the pad id is read at each use, as before
    sampling_config = GenerationConfig(
        max_new_tokens=self.args.response_length,
        temperature=(0.01 + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    columns = defaultdict(list)
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as policy_for_generation:
        for eval_batch in self.eval_dataloader:
            prompt_ids = eval_batch["input_ids"]
            with torch.no_grad():
                prompt_width = prompt_ids.shape[1]
                prompt_and_response, _ = batch_generation(
                    policy_for_generation,
                    prompt_ids,
                    prompt_ids.shape[0],
                    pad_token_id_of.pad_token_id,
                    sampling_config,
                )
                raw_response = prompt_and_response[:, prompt_width:]
                # handle the edge case when stop_token_id exists but is 0
                if args.stop_token_id is None:
                    trimmed_response = raw_response
                else:
                    trimmed_response = truncate_response(
                        args.stop_token_id, pad_token_id_of.pad_token_id, raw_response
                    )

                decoded_prompts = processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
                columns["query"].extend(gather_object(decoded_prompts))
                decoded_responses = processing_class.batch_decode(trimmed_response)
                columns["model response"].extend(gather_object(decoded_responses))

                scored_ids = torch.cat((prompt_ids, trimmed_response), 1)

                if isinstance(self.reward_model, nn.Module):
                    _, score, _ = get_reward(
                        self.reward_model,
                        scored_ids,
                        pad_token_id_of.pad_token_id,
                        prompt_width,
                    )
                else:
                    # Callable reward: score the decoded text and move the result next to the ids.
                    decoded_for_reward = processing_class.batch_decode(scored_ids, skip_special_tokens=True)
                    score = torch.tensor(
                        self.reward_model(decoded_for_reward),
                        dtype=torch.float,
                    ).to(scored_ids.device)
                gathered_scores = self.accelerator.gather_for_metrics(score)
                columns["score"].extend(gathered_scores.float().cpu().numpy())

            if sampling:
                break
    completions = pd.DataFrame(columns)

    if not self.accelerator.is_main_process:
        return

    print_rich_table(completions.iloc[0:5])
    if "wandb" in args.report_to:
        import wandb

        if wandb.run is not None:
            wandb.log({"completions": wandb.Table(dataframe=completions)})

    if "comet_ml" in args.report_to:
        log_table_to_comet_experiment(
            name="completions.csv",
            table=completions,
        )
|
| 998 |
+
|
| 999 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing happens on
    processes for which `self.is_world_process_zero()` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The caller's object is never modified.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalise `tags` to a fresh list so that appending below cannot mutate the caller's list.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    # Backslashes are doubled so the BibTeX accent commands (\' and \") survive Python's
    # string-escape processing and reach the model card intact.
    citation = textwrap.dedent("""\
        @inproceedings{ahmadian2024back,
            title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
            author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\\"{U}}st{\\"{u}}n and Sara Hooker},
            year = 2024,
            booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
            publisher = {Association for Computational Linguistics},
            pages = {12248--12267},
            editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
        }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="RLOO",
        trainer_citation=citation,
        paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
        paper_id="2402.14740",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1057 |
+
class UnslothRLOOTrainer(_UnslothRLOOTrainer):
    """
    Unsloth front-end for the RLOO trainer.

    Before delegating to `_UnslothRLOOTrainer.__init__` the constructor:

    * copies the policy's `max_seq_length` into the config when the config has the attribute but
      no value;
    * calls `policy.for_training()` when the policy provides it;
    * forces right padding on `processing_class` (and its inner `tokenizer`, if any);
    * swaps the data collator between `DataCollatorForSeq2Seq` and
      `DataCollatorForLanguageModeling` depending on whether the dataset has a `labels` column;
    * registers the RLOO statistics patch from `unsloth_zoo`.

    After the parent constructor returns, any NEFTune forward hook is removed and the noise alpha
    is set directly on the policy's input embeddings instead.
    """
    def __init__(
        self,
        config,
        processing_class,
        policy,
        ref_policy,
        reward_model,
        train_dataset,
        data_collator = None,
        eval_dataset = None,
        callbacks = None,
        **kwargs
    ):
        # The shared set-up code below is written in terms of `args` and `model`; this trainer's
        # signature names them `config` and `policy`, so bind the aliases explicitly. (Previously
        # `args` was read before assignment and `model` was never defined.)
        args = config
        if args is None:
            args = UnslothRLOOConfig()
            config = args
        model = policy

        # Ask Unsloth to return logits only when a metrics hook that needs them was supplied.
        # These hooks are not named parameters here, so they can only arrive through **kwargs.
        wants_logits = (
            kwargs.get('compute_metrics', None) is not None
            or kwargs.get('preprocess_logits_for_metrics', None) is not None
        )
        if wants_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit the policy's max_seq_length when the config has the field but left it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if args.max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # Force right padding on the processing class and on a wrapped tokenizer, if present.
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
            processing_class.tokenizer.padding_side = 'right'
        tokenizer_like = processing_class

        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        is_vision_collator = isinstance(data_collator, UnslothVisionDataCollator)
        if not is_vision_collator:
            # Pick the collator that matches the dataset: Seq2Seq needs a `labels` column,
            # language-modeling does not.
            has_labels = 'labels' in train_dataset.column_names
            if isinstance(data_collator, DataCollatorForSeq2Seq) and not has_labels:
                data_collator = DataCollatorForLanguageModeling(tokenizer_like, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and has_labels:
                data_collator = DataCollatorForSeq2Seq(tokenizer_like)
        else:
            # Vision collator: leave the dataset untouched for the collator to process.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}

        if not is_vision_collator:
            # A processor without `pad` that wraps a tokenizer: build the collator on the tokenizer.
            if not hasattr(tokenizer_like, 'pad') and hasattr(tokenizer_like, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(tokenizer_like.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(tokenizer_like.tokenizer, mlm = False)

        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('rloo_trainer', other_metrics)

        super().__init__(
            config = config,
            processing_class = processing_class,
            policy = policy,
            ref_policy = ref_policy,
            reward_model = reward_model,
            train_dataset = train_dataset,
            data_collator = data_collator,
            eval_dataset = eval_dataset,
            callbacks = callbacks,**kwargs)

        # Replace a NEFTune forward hook (if the parent installed one) by setting the noise alpha
        # directly on the input embeddings.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
unsloth_compiled_cache/UnslothRewardTrainer.py
ADDED
|
@@ -0,0 +1,819 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Option dictionary passed as `options=` to the `torch.compile` decorator
# applied to `selective_log_softmax` in this module.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Return the log-softmax of `logits` evaluated only at the positions in `index`.

    The computation uses the identity ``log_softmax(x)_i = x_i - logsumexp(x)``
    over the last dimension, so only the selected entries are produced rather
    than the full log-softmax tensor.

    Args:
        logits: Tensor of scores; it is cast to float32 before any arithmetic.
        index: Integer tensor of positions along the last dimension of
            `logits` (one index per leading position).
            NOTE(review): presumably token ids with shape ``logits.shape[:-1]`` - confirm against callers.

    Returns:
        A float32 tensor holding the selected log-probabilities.
    """
    scores = logits.to(torch.float32)
    # Pick out the raw score at each requested index along the last axis.
    chosen_scores = scores.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    # Normalisation term shared by every entry of a given row.
    log_normalizer = scores.logsumexp(dim = -1)
    return chosen_scores - log_normalizer
|
| 42 |
+
@dataclass
class UnslothRewardConfig(RewardConfig):
    """
    Configuration class for the [`RewardTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Every argument accepted by `__init__` is forwarded unchanged to `RewardConfig`, except
    `vllm_sampling_params` and `unsloth_num_chunks`, which are stored directly on the instance.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
            limit. This argument is required if you want to use the default data collator.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        dataset_num_proc (`int`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When left as `None`, the CPU count reported by
            `multiprocessing.cpu_count()` is used.
        center_rewards_coefficient (`float`, *optional*, defaults to `None`):
            Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
            https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
            the dataset is pretokenized.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient.
    """

    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=False,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        tp_size=0,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        evaluation_strategy=None,
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=False,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        dispatch_batches=None,
        split_batches=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        eval_use_gather_object=False,
        average_tokens_across_devices=False,
        max_length=1024,
        disable_dropout=True,
        dataset_num_proc=None,
        center_rewards_coefficient=None,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        # Reject learning rates outside [1e-7, 1] before anything else is set up.
        if learning_rate < 1e-7:
            raise FloatingPointError(
                f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!'
            )
        if learning_rate > 1:
            raise OverflowError(
                f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!'
            )

        # With no output directory and the checkpointing arguments left at
        # their default values, fall back to a fixed directory name and turn
        # checkpoint saving off.
        save_args_are_default = save_strategy == 'steps' and save_steps == 500
        if output_dir is None and save_args_are_default:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        # Default the dataset worker count to the machine's CPU count.
        if dataset_num_proc is None:
            import multiprocessing

            dataset_num_proc = multiprocessing.cpu_count()

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            tp_size=tp_size,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            evaluation_strategy=evaluation_strategy,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            dispatch_batches=dispatch_batches,
            split_batches=split_batches,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            max_length=max_length,
            disable_dropout=disable_dropout,
            dataset_num_proc=dataset_num_proc,
            center_rewards_coefficient=center_rewards_coefficient,
            **kwargs,
        )

        # These two settings are not forwarded to the parent initializer; they
        # are kept on the instance only.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
| 363 |
+
|
| 364 |
+
class _UnslothRewardTrainer(Trainer):
    """Trainer for pairwise reward models.

    Each batch is expected to carry a "chosen" and a "rejected" sequence
    (`input_ids_chosen` / `attention_mask_chosen` and
    `input_ids_rejected` / `attention_mask_rejected`), optionally with a
    per-example `margin`. The loss pushes the reward of the chosen sequence
    above the reward of the rejected one.
    """

    _tag_names = ["trl", "reward-trainer"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[RewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
    ):
        """
        Initialize RewardTrainer.

        Args:
            model (`transformers.PreTrainedModel`):
                The model to train, preferably an `AutoModelForSequenceClassification`.
            args (`RewardConfig`):
                The arguments to use for training.
            data_collator (`transformers.DataCollator`):
                The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
                which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
            train_dataset (`datasets.Dataset`):
                The dataset to use for training.
            eval_dataset (`datasets.Dataset`):
                The dataset to use for evaluation.
            processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
                Processing class used to process the data. If provided, will be used to automatically process the inputs
                for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
                reuse the fine-tuned model.
            model_init (`Callable[[], transformers.PreTrainedModel]`):
                The model initializer to use for training. If None is specified, the default model initializer will be used.
            compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
                The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
            callbacks (`list[transformers.TrainerCallback]`):
                The callbacks to use for training.
            optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
                The optimizer and scheduler to use for training.
            preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
                The function to use to preprocess the logits before computing the metrics.
            peft_config (`dict`, defaults to `None`):
                The PEFT configuration. When given and the model is quantized (and not already a `PeftModel`), the
                model is prepared for k-bit training. This initializer does not itself wrap the model in a PEFT model.

        Raises:
            ValueError: If `peft_config` is passed but PEFT is not installed, or if the default data collator is
                needed and no `processing_class` was given.
        """
        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    # Older peft releases do not accept `gradient_checkpointing_kwargs`.
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )

                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
                            UserWarning,
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
                # NOTE: `peft_config` is not applied to the model here; the
                # model is used as passed in (after optional k-bit preparation).

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        # Bound unconditionally: the dataset filters below need it even when the
        # caller supplies a custom data collator.
        max_length = args.max_length

        if data_collator is None:
            if processing_class is None:
                raise ValueError(
                    "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
                )

            data_collator = RewardDataCollatorWithPadding(processing_class)

            if args.remove_unused_columns:
                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                # warn users
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_reward_data_collator = True
        else:
            self.use_reward_data_collator = False

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
        # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
        # issued.
        model.warnings_issued["estimate_tokens"] = True

        # Tokenize on the fly only when the dataset is not already tokenized.
        if "input_ids_chosen" not in train_dataset.column_names:
            with PartialState().local_main_process_first():
                fn_kwargs = {"tokenizer": processing_class}
                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
                train_dataset = train_dataset.map(
                    _tokenize,
                    batched=True,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                )
                # This filter is important because otherwise you get samples that exceed the model's context length and
                # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
                # user might get surprised if N samples are missing from training.
                train_dataset = train_dataset.filter(
                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
                    num_proc=args.dataset_num_proc,
                )
                if eval_dataset is not None:
                    eval_dataset = eval_dataset.map(
                        maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
                    )
                    eval_dataset = eval_dataset.map(
                        _tokenize,
                        fn_kwargs=fn_kwargs,
                        batched=True,
                        num_proc=args.dataset_num_proc,
                    )
                    # This filter is important because otherwise you get samples that exceed the model's context length and
                    # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
                    # user might get surprised if N samples are missing from training.
                    eval_dataset = eval_dataset.filter(
                        lambda x: len(x["input_ids_chosen"]) <= max_length
                        and len(x["input_ids_rejected"]) <= max_length,
                        num_proc=args.dataset_num_proc,
                    )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
        """Compute the pairwise reward loss for one batch.

        The loss is ``-logsigmoid(r_chosen - r_rejected [- margin])`` averaged
        over the batch, plus an optional centering penalty
        ``center_rewards_coefficient * mean((r_chosen + r_rejected) ** 2)``.

        Args:
            model: Model called once on the chosen and once on the rejected inputs;
                its ``"logits"`` output is used as the reward.
            inputs: Batch with `input_ids_chosen`, `attention_mask_chosen`,
                `input_ids_rejected`, `attention_mask_rejected` and optionally `margin`.
            return_outputs: When true, also return both reward tensors.
            num_items_in_batch: Accepted for interface compatibility; not used here.

        Returns:
            The loss, or ``(loss, {"rewards_chosen": ..., "rewards_rejected": ...})``
            when `return_outputs` is true.
        """
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )["logits"]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )["logits"]
        # calculate loss, optionally modulate with margin
        if "margin" in inputs:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
        else:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

        if self.args.center_rewards_coefficient is not None:
            loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one evaluation step without gradients.

        Args:
            model: Model to evaluate.
            inputs: Batch in the same format `compute_loss` expects.
            prediction_loss_only: When true, only the loss is returned.
            ignore_keys: Output keys to drop; defaults to the model config's
                `keys_to_ignore_at_inference` when the model has a config, else `[]`.

        Returns:
            ``(loss, None, None)`` when `prediction_loss_only` is true, otherwise
            ``(loss, logits, labels)`` where `logits` holds, per example, the
            softmax over (chosen, rejected) rewards and `labels` is all zeros.
        """
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)
        # Stack accepted against rejected, mean over logits
        # and softmax to get preferences between accepted and rejected to sum to 1
        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T

        # Zero labels: the chosen sequence sits at index 0 of each logits row.
        labels = torch.zeros(logits.shape[0])
        labels = self._prepare_inputs(labels)

        return loss, logits, labels

    def evaluate(self, *args, **kwargs):
        """Print a few sample predictions, then run the parent `evaluate`.

        Accepts an extra keyword `num_print_samples` (default `4`), which is
        removed from `kwargs` before they are forwarded.
        """
        num_print_samples = kwargs.pop("num_print_samples", 4)
        self.visualize_samples(num_print_samples)
        return super().evaluate(*args, **kwargs)

    def visualize_samples(self, num_print_samples: int):
        """
        Visualize the reward model logits prediction

        Args:
            num_print_samples (`int`, defaults to `4`):
                The number of samples to print. Set to `-1` to print all samples.
        """
        eval_dataloader = self.get_eval_dataloader()
        table = defaultdict(list)
        for inputs in eval_dataloader:
            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
            chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
            rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
            table["chosen_text"].extend(gather_object(chosen_text))
            table["rejected_text"].extend(gather_object(rejected_text))
            table["logits"].extend(
                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
            )
            # Stop early once enough rows were collected (never when printing all).
            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
                break
        df = pd.DataFrame(table)
        # A negative count means "print everything"; slicing with `[:-1]`
        # would silently drop the last row instead.
        printed_rows = df if num_print_samples < 0 else df[:num_print_samples]
        if self.accelerator.process_index == 0:
            print_rich_table(printed_rows)
            if "wandb" in self.args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=df)})

            if "comet_ml" in self.args.report_to:
                log_table_to_comet_experiment(
                    name="completions.csv",
                    table=df,
                )

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written to `README.md` inside `self.args.output_dir`, and only on the
        world-zero process.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card. The passed list is not modified.
        """
        if not self.is_world_process_zero():
            return

        # Only report a base model when the name is not a local directory.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # Normalise to a fresh list so appending below never mutates the caller's object.
        if not tags:
            tags = []
        elif isinstance(tags, str):
            tags = [tags]
        else:
            tags = list(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="Reward",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 691 |
+
class UnslothRewardTrainer(_UnslothRewardTrainer):
    """
    Wrapper around `_UnslothRewardTrainer` that adjusts `args`, the data collator and a few
    environment variables before delegating to the parent constructor.

    Before calling `super().__init__` the constructor:

    - substitutes `UnslothRewardConfig()` when `args` is `None`;
    - reconciles `args.fp16` / `args.bf16` (and the `*_full_eval` flags) with the model dtype,
      raising `TypeError` when the requested precision contradicts it, and writes
      `ACCELERATE_MIXED_PRECISION` into `os.environ`;
    - tweaks evaluation batch size / accumulation steps when evaluation is enabled;
    - sets `UNSLOTH_RETURN_LOGITS=1` when metric callbacks are supplied;
    - copies `model.max_seq_length` into `args.max_seq_length` when the latter is unset;
    - calls `model.for_training()` when available and forces right padding on the
      processing class;
    - swaps between `DataCollatorForSeq2Seq` and `DataCollatorForLanguageModeling`
      depending on whether the training dataset has a `labels` column.

    After the parent constructor returns, any `neftune_hook_handle` is removed and
    `neftune_noise_alpha` is stored on the input embeddings instead.
    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        **kwargs
    ):
        if args is None: args = UnslothRewardConfig()
        # ---- Mixed-precision reconciliation -------------------------------------------
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # NOTE(review): `model` defaults to None but is dereferenced here without a guard;
        # a None model raises AttributeError at this point.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: follow the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        # ---- Evaluation defaults --------------------------------------------------------
        # NOTE(review): this reads `eval_dataset` from `args`, not the `eval_dataset`
        # parameter of this method; whether `args` ever carries that attribute cannot be
        # told from here — confirm the intent.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Only prints an upgrade hint for transformers <= 4.45.2; nothing is changed.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of exactly 8 is lowered to the (smaller) train batch size.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        # Keep the full-eval precision flags consistent with the training precision flags.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # ---- Logits for metric callbacks ------------------------------------------------
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        # ---- Sequence length ------------------------------------------------------------
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            # Fall back to the model's value only when `args` does not specify one.
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # ---- Padding side and data collator ---------------------------------------------
        # NOTE(review): no local named `tokenizer` is defined in this method, so the
        # `'tokenizer' in locals()` test below is false and `tokenizer` is never read.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Pick the collator type that matches the presence of a `labels` column.
            # NOTE(review): `train_dataset.column_names` is read without a None check.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator: turn off the trainer-side dataset preparation options.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing class has no `pad` but wraps a `.tokenizer`: rebuild the
            # collator around the inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []

        # presumably registers extra logging metrics for this trainer — verify in unsloth_zoo
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('reward_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,**kwargs)
        # ---- NEFTune: drop the hook, keep only the alpha on the embeddings -------------
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

pass
|
unsloth_compiled_cache/UnslothSFTTrainer.py
ADDED
|
@@ -0,0 +1,1031 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, wandb, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
# Options dictionary handed to `torch.compile(..., options = torch_compile_options)`
# by the compiled helpers defined below in this module.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` over the last dimension, evaluated only at `index`.

    `logits` is cast to float32 first. For every position, the entry selected by `index`
    along the last dimension is gathered and the log-sum-exp over that dimension is
    subtracted from it, using the identity log_softmax(x)_i = x_i - logsumexp(x).
    """
    logits_fp32 = logits.to(torch.float32)
    # Normaliser over the last dimension.
    log_normalizer = torch.logsumexp(logits_fp32, dim = -1)
    # Pick the single logit addressed by `index` at each position.
    chosen = torch.gather(logits_fp32, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    return chosen - log_normalizer
|
| 42 |
+
@dataclass
class UnslothSFTConfig(SFTConfig):
    """
    
    Configuration class for the [`SFTTrainer`].

    Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    The defaults quoted below are the ones declared in this class's `__init__` signature.

    Parameters:
        > Parameters that control the model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`SFTTrainer`] is provided as a string.
        use_liger (`bool`, *optional*, defaults to `False`):
            Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.

        > Parameters that control the data preprocessing

        dataset_text_field (`str`, *optional*, defaults to `"text"`):
            Name of the column that contains text data in the dataset.
        dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
            `skip_prepare_dataset`.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset. When `None`, it is replaced by
            `multiprocessing.cpu_count()` before being forwarded to the parent class.
        max_seq_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
            right.
            If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
        packing (`bool`, *optional*, defaults to `False`):
            Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
            length.
        eval_packing (`bool` or `None`, *optional*, defaults to `None`):
            Whether to pack the eval dataset. If `None`, uses the same value as `packing`.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `5e-05`):
            Initial learning rate for [`AdamW`] optimizer. Values below `1e-7` raise `FloatingPointError` and
            values above `1` raise `OverflowError`.

        > Parameters added by this class

        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to the parent class.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance; not forwarded to
            the parent class.

    """
    # Dataclass fields for the two options this class adds on top of `SFTConfig`.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        tp_size = 0,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        model_init_kwargs = None,
        use_liger = False,
        dataset_text_field = 'text',
        dataset_kwargs = None,
        dataset_num_proc = None,
        max_seq_length = None,
        packing = False,
        eval_packing = None,
        dataset_batch_size = None,
        num_of_sequences = None,
        chars_per_token = None,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Validate the learning rate, apply the output-dir and dataset_num_proc fallbacks,
        # forward everything else to `SFTConfig.__init__`, then store the two extra options.

        # Sanity bounds on the learning rate: reject values below 1e-7 or above 1.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # No output_dir and the save settings still equal 'steps' / 500: use a fixed
        # directory name and switch checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Default the dataset worker count to the number of CPUs.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            tp_size = tp_size,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            use_liger = use_liger,
            dataset_text_field = dataset_text_field,
            dataset_kwargs = dataset_kwargs,
            dataset_num_proc = dataset_num_proc,
            max_seq_length = max_seq_length,
            packing = packing,
            eval_packing = eval_packing,
            dataset_batch_size = dataset_batch_size,
            num_of_sequences = num_of_sequences,
            chars_per_token = chars_per_token,**kwargs)
        # The two Unsloth-specific options are not passed to the parent; keep them here.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
pass
|
| 399 |
+
|
| 400 |
+
class _UnslothSFTTrainer(Trainer):
|
| 401 |
+
""""""
|
| 402 |
+
|
| 403 |
+
_tag_names = ["trl", "sft"]
|
| 404 |
+
|
| 405 |
+
@deprecate_kwarg(
    "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
    self,
    model: Union[str, nn.Module, PreTrainedModel],
    args: Optional[Union[SFTConfig, TrainingArguments]] = None,
    data_collator: Optional[DataCollator] = None,  # type: ignore
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    compute_loss_func: Optional[Callable] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
    optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional["PeftConfig"] = None,
    formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
):
    """
    Resolve config, model, tokenizer, datasets and collator, then initialise the parent `Trainer`.

    Args:
        model: A model instance, or a string that is loaded through `_create_model_from_path`.
        args: Training configuration. `None` creates an `SFTConfig` named `"<model name>-SFT"`; a plain
            `TrainingArguments` is converted into an `SFTConfig` from its `to_dict()` output.
        data_collator: Batch collator. Defaults to `DataCollatorForLanguageModeling(..., mlm=False)`.
        train_dataset: Training data. Passed through `_prepare_dataset` unless
            `args.dataset_kwargs["skip_prepare_dataset"]` is truthy.
        eval_dataset: Optional evaluation data (one dataset or a dict of named datasets), prepared under the
            same condition as `train_dataset`.
        processing_class: Tokenizer/processor. When `None` it is loaded with `AutoTokenizer.from_pretrained`
            from `model.config._name_or_path`.
        compute_loss_func: Forwarded unchanged to the parent initialiser.
        compute_metrics: Forwarded unchanged to the parent initialiser.
        callbacks: Forwarded unchanged to the parent initialiser.
        optimizers: Forwarded unchanged to the parent initialiser.
        optimizer_cls_and_kwargs: Forwarded only for `transformers>=4.47.0.dev0`; otherwise a warning is
            emitted and the value is dropped.
        preprocess_logits_for_metrics: Forwarded unchanged to the parent initialiser.
        peft_config: Accepted for interface compatibility; the PEFT wrapping branch below is hard-coded off.
        formatting_func: Passed to `_prepare_dataset`.
    """
    # Args
    if args is None:
        model_name = model if isinstance(model, str) else model.config._name_or_path
        model_name = model_name.split("/")[-1]
        args = SFTConfig(f"{model_name}-SFT")
    elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
        dict_args = args.to_dict()
        dict_args["hub_token"] = args.hub_token  # to_dict hides the hub_token
        # Tolerate configurations whose `to_dict()` does not emit this key instead of raising KeyError.
        dict_args.pop("push_to_hub_token", None)
        args = SFTConfig(**dict_args)

    # Model
    if args.model_init_kwargs is not None and not isinstance(model, str):
        warnings.warn(
            "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
            "The `model_init_kwargs` will be ignored."
        )
    if isinstance(model, str):
        model = self._create_model_from_path(model, args)

    # PEFT configuration and model wrapping (this branch is disabled in this generated trainer)
    if False:
        model = self._prepare_peft_model(model, peft_config, args)

    # Handle the tokenizer
    if processing_class is None:
        processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
        if processing_class.pad_token is None:
            processing_class.pad_token = processing_class.eos_token  # required for padding when collating data

    # Dataset
    preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
    if preprocess_dataset:
        train_dataset = self._prepare_dataset(
            train_dataset, processing_class, args, args.packing, formatting_func, "train"
        )
        if eval_dataset is not None:
            # `eval_packing` overrides `packing` for evaluation data only when it is set.
            packing = args.packing if args.eval_packing is None else args.eval_packing
            if isinstance(eval_dataset, dict):
                eval_dataset = {
                    key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                    for key, dataset in eval_dataset.items()
                }
            else:
                eval_dataset = self._prepare_dataset(
                    eval_dataset, processing_class, args, packing, formatting_func, "eval"
                )

    # Data collator
    if data_collator is None:
        data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)

    # Initialize the metrics (buffered per key, averaged and cleared in `log`)
    self._metrics = defaultdict(list)

    # Initialize the Trainer. Parent class will handle:
    # - DeepSpeed configuration (through create_accelerator_and_postprocess)
    # - FSDP setup
    # - Distributed training setup
    # - Optimizer and scheduler creation
    # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
    super_init_kwargs = {}
    if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
        super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
    else:
        if optimizer_cls_and_kwargs is not None:
            warnings.warn(
                "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
                "The default optimizer will be used. "
                "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
            )
    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        compute_loss_func=compute_loss_func,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        **super_init_kwargs,
    )

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)
|
| 516 |
+
|
| 517 |
+
def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
|
| 518 |
+
"""Creates a model from a path or model identifier."""
|
| 519 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 520 |
+
# Handle torch dtype
|
| 521 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 522 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
| 523 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 524 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 525 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 526 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 527 |
+
else:
|
| 528 |
+
raise ValueError(
|
| 529 |
+
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
|
| 530 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 531 |
+
)
|
| 532 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 533 |
+
if args.gradient_checkpointing:
|
| 534 |
+
model_init_kwargs["use_cache"] = False
|
| 535 |
+
|
| 536 |
+
# Create model
|
| 537 |
+
if args.use_liger:
|
| 538 |
+
if not is_liger_kernel_available():
|
| 539 |
+
raise ImportError("Please install Liger-kernel for use_liger=True")
|
| 540 |
+
model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
| 541 |
+
else:
|
| 542 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
| 543 |
+
return model
|
| 544 |
+
|
| 545 |
+
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
    """
    Wrap `model` with the given PEFT configuration.

    Args:
        model: Model to wrap. An existing `PeftModel` is returned untouched.
        peft_config: Must be a `PeftConfig` instance.
        args: Configuration read for `gradient_checkpointing` and `bf16`.

    Returns:
        The wrapped model (or the input model when it already is a `PeftModel`).

    Raises:
        ImportError: `peft` is not available.
        ValueError: `peft_config` is not a `PeftConfig`.
    """
    if not is_peft_available():
        raise ImportError("To use PeftModel, you need to install the `peft` library.")

    if not isinstance(peft_config, PeftConfig):
        raise ValueError(
            f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
            "to pass a PeftConfig object to the SFTTrainer."
        )

    if isinstance(model, PeftModel):
        return model

    # Quantized (QLoRA) detection
    loaded_in_4bit = getattr(model, "is_loaded_in_4bit", False)
    quantized = loaded_in_4bit or getattr(model, "is_loaded_in_8bit", False)

    # A 4-bit model counts as sharded (FSDP/DS-Zero3) when its first `Params4bit` parameter lives on cpu/meta.
    sharded = False
    if loaded_in_4bit:
        first_4bit_param = next(
            (param for _, param in model.named_parameters() if param.__class__.__name__ == "Params4bit"),
            None,
        )
        if first_4bit_param is not None:
            sharded = first_4bit_param.data.device.type in {"cpu", "meta"}

    if quantized and not sharded:
        model = self._prepare_model_for_kbit_training(model, args)
        # Gradient checkpointing is handled by prepare_model_for_kbit_training, so turn it off on the local copy.
        args = dataclasses.replace(args, gradient_checkpointing=False)
    elif args.gradient_checkpointing:
        model = self._enable_gradient_checkpointing(model, args)

    # `autocast_adapter_dtype` was introduced in peft 0.12 and is only passed for sharded 4-bit models.
    extra_peft_kwargs = {}
    if (
        version.parse(peft.__version__) >= version.parse("0.12")
        and getattr(model, "is_loaded_in_4bit", False)
        and sharded
    ):
        extra_peft_kwargs["autocast_adapter_dtype"] = False
    model = get_peft_model(model, peft_config, **extra_peft_kwargs)

    # bf16 casting for non-sharded 4-bit models
    if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not sharded:
        peft_module_casting_to_bf16(model)

    return model
|
| 593 |
+
|
| 594 |
+
def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
    """
    Run `prepare_model_for_kbit_training` on a quantized model.

    The gradient-checkpointing flag and its keyword arguments (an empty dict when unset) are taken from `args`.
    """
    return prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
    )
|
| 602 |
+
|
| 603 |
+
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
|
| 604 |
+
"""Enables gradient checkpointing for the model."""
|
| 605 |
+
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 606 |
+
use_reentrant = (
|
| 607 |
+
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
if use_reentrant:
|
| 611 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 612 |
+
model.enable_input_require_grads()
|
| 613 |
+
else:
|
| 614 |
+
|
| 615 |
+
def make_inputs_require_grad(module, input, output):
|
| 616 |
+
output.requires_grad_(True)
|
| 617 |
+
|
| 618 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 619 |
+
|
| 620 |
+
return model
|
| 621 |
+
|
| 622 |
+
def _prepare_dataset(
    self,
    dataset: Union[Dataset, IterableDataset],
    processing_class,
    args,
    packing: bool,
    formatting_func: Optional[Callable[[dict], str]],
    dataset_name: str,
) -> Union[Dataset, IterableDataset]:
    """
    Tokenize `dataset` for SFT unless it already carries token columns.

    Args:
        dataset: Dataset to prepare. A `ConstantLengthDataset` is returned untouched.
        processing_class: Tokenizer, or a processor exposing a `.tokenizer` attribute (treated as a VLM).
        args: Configuration read for `max_length` / `max_seq_length`, `dataset_text_field` and
            `dataset_num_proc`.
        packing: When truthy, a notice is printed; packing itself is not performed by this method.
        formatting_func: Required when the text field is missing from the dataset. Called on a batch and
            expected to return a list of strings.
        dataset_name: Label of the split; not used by the current implementation.

    Returns:
        The dataset, tokenized when it had neither `labels` nor `input_ids` columns.

    Raises:
        RuntimeError: No maximum sequence length could be determined, a VLM tokenizer lacks `.pad`, or a
            `formatting_func` is needed but missing.
        ValueError: `formatting_func` does not return a list.

    Side effects:
        May replace `self.data_collator` to match the dataset's columns.
    """
    # All Unsloth Zoo code licensed under LGPLv3
    if isinstance(dataset, ConstantLengthDataset):
        return dataset

    map_kwargs = {}
    use_desc = isinstance(dataset, Dataset)
    is_vlm = hasattr(processing_class, "tokenizer")
    tokenizer = processing_class.tokenizer if is_vlm else processing_class

    # Get max length: the first non-zero value among the candidates below wins.
    max_seq_length = getattr(args, "max_length", 0)
    if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
    if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
    if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
    if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
    dataset_text_field = getattr(args, "dataset_text_field", "text")
    do_formatting_func = False
    do_tokenize = True

    # Peek at one row once; it provides both the column names and the sample for the BOS check.
    first_row = next(iter(dataset))
    column_names = set(first_row.keys())

    # Check if already tokenized so skip
    from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
    if "labels" in column_names:
        # Most likely forgot data collator!
        if is_vlm and not hasattr(tokenizer, "pad"):
            # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
            raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
        self.data_collator = DataCollatorForSeq2Seq(tokenizer)
        do_tokenize = False
    elif "input_ids" in column_names:
        # Skip dataset prep, and set data collator
        if is_vlm and not hasattr(tokenizer, "pad"):
            # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
            raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
        self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
        do_tokenize = False
    elif dataset_text_field not in column_names:
        do_formatting_func = True
        if formatting_func is None:
            raise RuntimeError("Unsloth: You must specify a `formatting_func`")

    if do_tokenize:
        # Check double BOS tokens
        if do_formatting_func:
            test_text = formatting_func(first_row)
            if not isinstance(test_text, list):
                raise ValueError(
                    "Unsloth: The `formatting_func` should return a list of processed strings."
                )
            test_text = test_text[0] if test_text else ""
        else:
            test_text = first_row[dataset_text_field]
            # A single row holds the whole text. Indexing a string with [0] would keep only its first
            # character, so `startswith(bos_token)` could never match a multi-character BOS token.
            # Only unwrap when the field is itself a sequence of strings.
            if isinstance(test_text, (list, tuple)):
                test_text = test_text[0] if test_text else ""

        # Get chat template
        chat_template = getattr(processing_class, 'chat_template', '')
        if chat_template == '' and is_vlm:
            chat_template = getattr(tokenizer, 'chat_template', '')
        if chat_template is None:
            chat_template = ''

        # Get bos_token
        add_special_tokens = True
        bos_token = getattr(processing_class, 'bos_token', None) or getattr(tokenizer, 'bos_token', None)

        if bos_token is not None:
            text_has_bos = isinstance(test_text, str) and test_text.startswith(bos_token)
            if text_has_bos or bos_token in chat_template:
                # The text or template already supplies BOS; do not let the tokenizer add a second one.
                add_special_tokens = False
                print("Unsloth: We found double BOS tokens - we shall remove one automatically.")

        # Create tokenize function
        def _tokenize(example):
            return tokenizer(
                example[dataset_text_field] if not do_formatting_func else formatting_func(example),
                truncation = True,  # a zero max_seq_length was rejected above, so truncation is always on
                max_length = max_seq_length,
                return_token_type_ids = False,
                add_special_tokens = add_special_tokens,
            )

        if not isinstance(dataset, IterableDataset):
            map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
        else:
            map_kwargs["batch_size"] = dataset._ex_iterable.batch_size

        if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
        dataset = dataset.map(_tokenize, batched = True, **map_kwargs)

        # If VLM, switch data collator since .pad is needed!
        if is_vlm and not hasattr(processing_class, "pad"):
            self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)

    if packing:
        # Packing is deliberately skipped: the dataset is returned unpacked after this notice.
        print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
    return dataset
|
| 752 |
+
|
| 753 |
+
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
    """Forward the call to the parent class's `compute_loss` and return its result unchanged."""
    return super().compute_loss(
        model,
        inputs,
        return_outputs = return_outputs,
        num_items_in_batch = num_items_in_batch,
    )
|
| 761 |
+
|
| 762 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Merge the averaged buffered metrics into `logs`, forward them to the parent logger, then clear the buffer.

    Args:
        logs: Values to log. When the first key starts with `"eval_"`, the buffered metrics are prefixed
            with `"eval_"` as well.
        start_time: Passed to the parent `log` only for `transformers>=4.47.0.dev0`.
    """
    # Average the metrics; a key with an empty buffer is skipped instead of dividing by zero.
    metrics = {key: sum(vals) / len(vals) for key, vals in self._metrics.items() if vals}

    # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
    # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
    # An empty `logs` is treated as a training call rather than raising StopIteration.
    first_key = next(iter(logs), "")
    if first_key.startswith("eval_"):
        metrics = {f"eval_{key}": val for key, val in metrics.items()}

    logs = {**logs, **metrics}
    if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
        super().log(logs, start_time)
    else:  # transformers<=4.46
        super().log(logs)
    self._metrics.clear()
|
| 776 |
+
|
| 777 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing happens on processes other
    than the world-zero process.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The caller's object is never modified.
    """
    if not self.is_world_process_zero():
        return

    # A local directory is not reported as the base model.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Work on a private list so that appending below cannot mutate the caller's argument.
    if isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags or [])

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="SFT",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 821 |
+
class UnslothSFTTrainer(_UnslothSFTTrainer):
    """

    Trainer for Supervised Fine-Tuning (SFT) method.

    This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.

    Example:

    ```python
    from datasets import load_dataset
    from trl import SFTTrainer

    dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

    trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
              in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        args ([`SFTConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator (`DataCollator`, *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
            Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
            of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
            tokenizer.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
            [prompt-completion](#prompt-completion) type. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).

            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
            with [`~transformers.AutoTokenizer.from_pretrained`].
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
            A tuple containing the optimizer class and keyword arguments to use.
            Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.

            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
            A function that preprocess the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
        formatting_func (`Optional[Callable]`):
            Formatting function applied to the dataset before tokenization.

    """
    def __init__(
        self,
        model,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_loss_func = None,
        compute_metrics = None,
        callbacks = None,
        optimizer_cls_and_kwargs = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        """
        Normalise precision, evaluation and collator settings on `args`, then defer to `_UnslothSFTTrainer`.

        Reads the environment variables `UNSLOTH_FORCE_FLOAT32`, `UNSLOTH_MIXED_PRECISION` and
        `UNSLOTH_IGNORED_TOKENIZER_NAMES`, and may set `ACCELERATE_MIXED_PRECISION` and
        `UNSLOTH_RETURN_LOGITS`. `args` is modified in place. Extra keyword arguments are forwarded to the
        parent initialiser.

        Raises:
            TypeError: The requested fp16/bf16 flag contradicts the model's dtype (unless float32 is forced).
        """
        if args is None: args = UnslothSFTConfig()

        # ---- Mixed precision: reconcile args.fp16 / args.bf16 with the model's dtype ----
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: follow the model's own dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # ---- Evaluation defaults ----
        # NOTE(review): this inspects `args.eval_dataset`, not the `eval_dataset` parameter - confirm intent.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Metric callbacks need logits, which is signalled through the environment.
        if compute_metrics is not None or preprocess_logits_for_metrics is not None:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # ---- Sequence length: inherit the model's value when the config leaves it unset ----
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if getattr(args, 'max_seq_length', None) is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # ---- Tokenizer padding side ----
        if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
        if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'

        # ---- Data collator: match it to the presence of a `labels` column ----
        collator_tokenizer = processing_class
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Column names can be unavailable (no dataset given, or a dataset that reports `None`);
            # in that case the collator is left alone instead of crashing on the membership test.
            column_names = getattr(train_dataset, 'column_names', None)
            if column_names is not None:
                if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in column_names:
                    data_collator = DataCollatorForLanguageModeling(collator_tokenizer, mlm = False)
                elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in column_names:
                    data_collator = DataCollatorForSeq2Seq(collator_tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # A processor without `.pad` cannot collate; rebuild the collator on its inner tokenizer.
            if not hasattr(collator_tokenizer, 'pad') and hasattr(collator_tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(collator_tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(collator_tokenizer.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('sft_trainer', other_metrics)
        IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
        from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
        from unsloth_zoo.training_utils import fix_zero_training_loss
        tokenizer = processing_class
        fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
        fix_zero_training_loss(model, tokenizer, train_dataset)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_loss_func = compute_loss_func,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,**kwargs)

        # Drop any NEFTune forward hook left on the trainer and store the noise level on the embeddings instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
unsloth_compiled_cache/UnslothXPOTrainer.py
ADDED
|
@@ -0,0 +1,1010 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.3
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
torch_compile_options = {
|
| 26 |
+
"epilogue_fusion" : True,
|
| 27 |
+
"max_autotune" : False,
|
| 28 |
+
"shape_padding" : True,
|
| 29 |
+
"trace.enabled" : False,
|
| 30 |
+
"triton.cudagraphs" : False,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 34 |
+
def selective_log_softmax(logits, index):
|
| 35 |
+
logits = logits.to(torch.float32)
|
| 36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
| 37 |
+
# loop to reduce peak mem consumption
|
| 38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
| 39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
| 40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
| 41 |
+
return per_token_logps
|
| 42 |
+
@dataclass
class UnslothXPOConfig(XPOConfig):
    """
    Configuration class for the [`XPOTrainer`], extended with two Unsloth-specific fields.

    Subclass of [`OnlineDPOConfig`]; all of its arguments can be used, plus the following:

    Parameters:
        alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
            Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
            and the last alpha is used for the rest of the epochs.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient.

    The constructor validates `learning_rate`, fills in a default `output_dir` and
    `dataset_num_proc` when they are not given, forwards everything else to the
    parent config, and finally stores the two extra fields on the instance.

    Raises:
        FloatingPointError: if `learning_rate` is below 1e-7.
        OverflowError: if `learning_rate` is above 1.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        tp_size = 0,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        reward_model_path = None,
        judge = None,
        max_new_tokens = 64,
        max_length = 512,
        temperature = 0.9,
        missing_eos_penalty = None,
        loss_type = 'sigmoid',
        dataset_num_proc = None,
        disable_dropout = True,
        use_vllm = False,
        ds3_gather_for_generation = True,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Fail fast on learning rates outside the [1e-7, 1] range.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # No output_dir and save settings still at 'steps'/500: use a default
        # directory name and switch checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Default the dataset worker count to the CPU count reported by multiprocessing.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        # Forward every parent-config argument by keyword; the two Unsloth-only
        # options are deliberately left out of this call and set afterwards.
        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            tp_size = tp_size,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            reward_model_path = reward_model_path,
            judge = judge,
            max_new_tokens = max_new_tokens,
            max_length = max_length,
            temperature = temperature,
            missing_eos_penalty = missing_eos_penalty,
            loss_type = loss_type,
            dataset_num_proc = dataset_num_proc,
            disable_dropout = disable_dropout,
            use_vllm = use_vllm,
            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
        # Stored directly on the instance rather than passed to the parent constructor.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
pass
|
| 365 |
+
|
| 366 |
+
class _UnslothXPOTrainer(OnlineDPOTrainer):
|
| 367 |
+
r""""""
|
| 368 |
+
|
| 369 |
+
_tag_names = ["trl", "xpo"]
|
| 370 |
+
|
| 371 |
+
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module] = None,
    ref_model: Union[PreTrainedModel, nn.Module] = None,
    reward_model: Optional[nn.Module] = None,
    judge: Optional[BasePairwiseJudge] = None,
    args: Optional[XPOConfig] = None,
    data_collator: Optional[Callable] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
    """Set up the trainer through the parent class, then install XPO state.

    Every argument is handed to the parent constructor by keyword;
    `processing_class` is additionally passed as `reward_processing_class`.
    After that, `args.alpha` is cached on the instance and `self.stats` is
    replaced by a dictionary of empty lists keyed by the XPO metric names
    (with three extra score series when a reward model is configured).
    """
    parent_kwargs = {
        "model": model,
        "ref_model": ref_model,
        "judge": judge,
        "reward_model": reward_model,
        "args": args,
        "data_collator": data_collator,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "processing_class": processing_class,
        # for now, XPOTrainer can't use any reward model
        "reward_processing_class": processing_class,
        "peft_config": peft_config,
        "compute_metrics": compute_metrics,
        "callbacks": callbacks,
        "optimizers": optimizers,
        "preprocess_logits_for_metrics": preprocess_logits_for_metrics,
    }
    super().__init__(**parent_kwargs)

    self._alpha = self.args.alpha

    # XPO-specific statistics: the parent's "non_score_reward", "rlhf_reward"
    # and "scores" are dropped, "loss/dpo" and "loss/xpo" are added, and
    # "contain_eos_token" is split into a model and a reference series.
    metric_names = (
        "loss/dpo",
        "loss/xpo",
        "objective/kl",
        "objective/entropy",
        "rewards/chosen",
        "rewards/rejected",
        "rewards/accuracies",
        "rewards/margins",
        "logps/chosen",
        "logps/rejected",
        "val/model_contain_eos_token",
        "val/ref_contain_eos_token",
        "alpha",
        "beta",
    )
    # A comprehension gives each metric its own independent list.
    self.stats = {metric: [] for metric in metric_names}

    if self.reward_model is not None:
        # "scores" becomes separate model / reference series plus their margin.
        for metric in ("objective/model_scores", "objective/ref_scores", "objective/scores_margin"):
            self.stats[metric] = []
|
| 435 |
+
|
| 436 |
+
@property
def alpha(self):
    """Return the XPO loss weight that applies to the current epoch.

    When the configured alpha is a single value it is returned unchanged.
    When it is a list, the entry for the current (whole) epoch is returned,
    and the last entry is reused once training has run past the end of the
    list.

    The trainer state reports `epoch` as a fractional float (and `None`
    before training has started), which cannot be used as a list index
    directly; it is therefore truncated to an integer epoch number, with
    `None` treated as epoch 0.
    """
    if not isinstance(self._alpha, list):
        return self._alpha
    # int() truncates e.g. 1.75 -> 1, so the whole of epoch 1 uses self._alpha[1].
    epoch = int(self.state.epoch or 0)
    return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
|
| 443 |
+
|
| 444 |
+
def _generate_completions(self, prompts, model):
    """Generate one batch of completions from the policy and one from the reference.

    Both generations use the same prompt `input_ids` / `attention_mask` and
    `self.generation_config`. The reference generation runs under
    `torch.no_grad()`; when no separate reference model is set, `model`
    itself is used as the reference.

    Returns:
        Tuple `(model_output, ref_output)` with the raw `generate` results.
    """

    def _sample(generator):
        # Identical generation call for policy and reference.
        return generator.generate(
            input_ids=prompts["input_ids"],
            attention_mask=prompts["attention_mask"],
            generation_config=self.generation_config,
        )

    with unwrap_model_for_generation(model, self.accelerator) as policy:
        model_output = _sample(policy)

    reference = self.ref_model if self.ref_model is not None else model
    with torch.no_grad(), unwrap_model_for_generation(reference, self.accelerator) as ref_policy:
        ref_output = _sample(ref_policy)

    return model_output, ref_output
|
| 461 |
+
|
| 462 |
+
def _process_completions(self, model_output, ref_output, prompts):
    """Turn raw generations into prompt+completion batches for both models.

    For each generation, the tokens after the prompt length are passed to
    `truncate_right` together with the EOS and pad token ids, which yields
    the completion ids and a matching mask. These are concatenated onto the
    prompt ids and prompt attention mask along dimension 1.

    Returns:
        Tuple `(model_data, ref_data)`; each is a dict with `"input_ids"`,
        `"attention_mask"` and the untouched `"raw"` prompts.
    """
    prompt_ids = prompts["input_ids"]
    prompt_mask = prompts["attention_mask"]
    context_length = prompt_ids.shape[1]
    tokenizer = self.processing_class

    def _assemble(generated):
        # Drop the prompt part, then truncate/mask the completion.
        completion_ids, completion_mask = truncate_right(
            generated[:, context_length:], tokenizer.eos_token_id, tokenizer.pad_token_id
        )
        return {
            "input_ids": torch.cat((prompt_ids, completion_ids), dim=1),
            "attention_mask": torch.cat((prompt_mask, completion_mask), dim=1),
            "raw": prompts["raw"],
        }

    model_data = _assemble(model_output)
    ref_data = _assemble(ref_output)
    return model_data, ref_data
|
| 488 |
+
|
| 489 |
+
def _compute_rewards(self, model_data, ref_data, context_length):
    """Score the policy and reference completions with the reward model.

    Both batches are scored through `get_reward` under `torch.no_grad()`.
    If `args.missing_eos_penalty` is set, it is subtracted in place from the
    score of every sequence whose `input_ids` contain no EOS token.

    Returns:
        Tuple `(model_scores, ref_scores)`.
    """
    pad_id = self.processing_class.pad_token_id
    with torch.no_grad():
        _, model_scores, _ = get_reward(self.reward_model, model_data["input_ids"], pad_id, context_length)
        _, ref_scores, _ = get_reward(self.reward_model, ref_data["input_ids"], pad_id, context_length)

    penalty = self.args.missing_eos_penalty
    if penalty is not None:
        eos_id = self.processing_class.eos_token_id
        for scores, batch in ((model_scores, model_data), (ref_scores, ref_data)):
            has_eos = torch.any(batch["input_ids"] == eos_id, dim=-1)
            # In-place update of the rows that never produced EOS.
            scores[~has_eos] -= penalty

    return model_scores, ref_scores
|
| 506 |
+
|
| 507 |
+
def _compute_judge(self, model_data, ref_data, context_length):
    """Ask the pairwise judge whether the policy completion beats the reference.

    The completion part of each batch (tokens from `context_length` on) is
    decoded without special tokens and stripped. For conversational prompts,
    the prompts and both sets of completions are rendered to plain text with
    `SIMPLE_CHAT_TEMPLATE`, each completion wrapped as a single assistant
    message. The judge then receives the prompts and the
    (policy, reference) completion pairs.

    Returns:
        Boolean tensor on the device of `model_data["input_ids"]`; an entry
        is True where the judge returned rank 0, i.e. the first (policy)
        completion is the best.
    """

    def _decode(batch):
        texts = self.processing_class.batch_decode(
            batch["input_ids"][:, context_length:], skip_special_tokens=True
        )
        return [text.strip() for text in texts]

    prompts = model_data["raw"]
    model_data_completions = _decode(model_data)
    ref_data_completions = _decode(ref_data)

    if is_conversational({"prompt": prompts[0]}):
        template = jinja2.Environment().from_string(SIMPLE_CHAT_TEMPLATE)

        def _render_reply(text):
            # Wrap the completion as a one-message assistant turn before rendering.
            return template.render(messages=[{"role": "assistant", "content": text}])

        prompts = [template.render(messages=message) for message in prompts]
        model_data_completions = [_render_reply(text) for text in model_data_completions]
        ref_data_completions = [_render_reply(text) for text in ref_data_completions]

    completion_pairs = list(zip(model_data_completions, ref_data_completions))
    ranks_of_first_completion = self.judge.judge(prompts, completion_pairs)

    # rank == 0: first completion is the best; rank == 1: second completion is the best.
    first_is_best = [rank == 0 for rank in ranks_of_first_completion]
    return torch.tensor(first_is_best, device=model_data["input_ids"].device)
|
| 541 |
+
|
| 542 |
+
def _compute_logprobs(self, model, model_data, ref_data, context_length):
    """Compute per-token completion log-probs under the policy and the reference.

    Args:
        model: The policy being trained.
        model_data: Mapping with ``"input_ids"`` / ``"attention_mask"`` for policy completions.
        ref_data: Same mapping for reference completions.
        context_length: Number of leading (prompt) positions.

    Returns:
        ``(model_logprobs_model_data, model_logprobs_ref_data,
        ref_logprobs_ref_data, ref_logprobs_model_data)``; positions whose
        attention mask is 0 are filled with ``0.0``.
    """

    def completion_logprobs(net, batch):
        logits = net(batch["input_ids"], attention_mask=batch["attention_mask"]).logits
        # Logits are taken one position earlier than the token ids they are paired with.
        shifted_logits = logits[:, context_length - 1 : -1]
        return selective_log_softmax(shifted_logits, batch["input_ids"][:, context_length:])

    # Policy log-probs (gradients flow through these); the ones on the
    # reference completions are also needed for the XPO term of the loss.
    policy_on_model = completion_logprobs(model, model_data)
    policy_on_ref = completion_logprobs(model, ref_data)

    # Reference log-probs, never differentiated.
    with torch.no_grad():
        if self.ref_model is not None:
            reference_on_model = completion_logprobs(self.ref_model, model_data)
            reference_on_ref = completion_logprobs(self.ref_model, ref_data)
        else:
            # No separate reference model: run the policy with its adapter disabled.
            with model.disable_adapter():
                reference_on_model = completion_logprobs(model, model_data)
                reference_on_ref = completion_logprobs(model, ref_data)

    # Zero out padded completion positions.
    model_pad = model_data["attention_mask"][:, context_length:] == 0
    ref_pad = ref_data["attention_mask"][:, context_length:] == 0

    return (
        policy_on_model.masked_fill(model_pad, 0.0),
        policy_on_ref.masked_fill(ref_pad, 0.0),
        reference_on_ref.masked_fill(ref_pad, 0.0),
        reference_on_model.masked_fill(model_pad, 0.0),
    )
|
| 573 |
+
|
| 574 |
+
def _compute_losses(
    self,
    model_logprobs_model_data,
    model_logprobs_ref_data,
    ref_logprobs_ref_data,
    ref_logprobs_model_data,
    chosen_mask,
):
    """Combine the DPO-style preference loss with the XPO term.

    Args:
        model_logprobs_model_data: Policy per-token log-probs on policy completions.
        model_logprobs_ref_data: Policy per-token log-probs on reference completions.
        ref_logprobs_ref_data: Reference per-token log-probs on reference completions.
        ref_logprobs_model_data: Reference per-token log-probs on policy completions.
        chosen_mask: Boolean tensor, ``True`` where the policy completion is the chosen one.

    Returns:
        ``(loss, dpo_losses, xpo_losses)`` where ``loss`` is the mean of the
        per-sample sum ``dpo_losses + xpo_losses``.

    Raises:
        NotImplementedError: If ``self.args.loss_type`` is neither ``"sigmoid"`` nor ``"ipo"``.
    """
    # Sequence-level log-probs (sum over dim 1).
    policy_on_model = model_logprobs_model_data.sum(1)
    policy_on_ref = model_logprobs_ref_data.sum(1)
    reference_on_ref = ref_logprobs_ref_data.sum(1)
    reference_on_model = ref_logprobs_model_data.sum(1)

    # For each row pick the policy completion where it was chosen, else the reference one
    # (and the opposite selection for the rejected side).
    rejected_mask = ~chosen_mask
    chosen_log_ratios = torch.where(chosen_mask, policy_on_model, policy_on_ref) - torch.where(
        chosen_mask, reference_on_model, reference_on_ref
    )
    rejected_log_ratios = torch.where(rejected_mask, policy_on_model, policy_on_ref) - torch.where(
        rejected_mask, reference_on_model, reference_on_ref
    )

    # Preference logits: chosen log-ratio minus rejected log-ratio.
    logits = chosen_log_ratios - rejected_log_ratios

    loss_type = self.args.loss_type
    if loss_type not in ("sigmoid", "ipo"):
        raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
    if loss_type == "sigmoid":
        dpo_losses = -F.logsigmoid(self.beta * logits)
    else:
        dpo_losses = (logits - 1 / (2 * self.beta)) ** 2

    # XPO-specific term: alpha times the policy log-prob of the reference completion.
    xpo_losses = self.alpha * policy_on_ref

    loss = (dpo_losses + xpo_losses).mean()
    return loss, dpo_losses, xpo_losses
|
| 613 |
+
|
| 614 |
+
def _log_statistics(
    self,
    model_data,
    ref_data,
    model_logprobs_model_data,
    model_logprobs_ref_data,
    ref_logprobs_ref_data,
    ref_logprobs_model_data,
    chosen_mask,
    dpo_losses,
    xpo_losses,
    context_length,
    model_scores=None,
    ref_scores=None,
):
    """Append one step's training metrics to ``self.stats``.

    Every tensor metric is passed through ``self.accelerator.gather_for_metrics``,
    averaged, and stored as a Python float under its key. ``alpha`` and ``beta``
    are appended as-is. Reward-model scores are only logged when
    ``self.reward_model`` is set. Returns ``None``.
    """
    stats = self.stats
    gather = self.accelerator.gather_for_metrics

    def record(key, tensor):
        # Gather, average, convert to float, append.
        stats[key].append(gather(tensor).mean().item())

    # Losses
    record("loss/dpo", dpo_losses)
    record("loss/xpo", xpo_losses)

    # Reward-model scores (absent when a judge is used instead)
    if self.reward_model is not None:
        record("objective/model_scores", model_scores)
        record("objective/ref_scores", ref_scores)
        record("objective/scores_margin", model_scores - ref_scores)

    # Sequence-level log-probs
    policy_on_model = model_logprobs_model_data.sum(1)
    policy_on_ref = model_logprobs_ref_data.sum(1)
    reference_on_ref = ref_logprobs_ref_data.sum(1)
    reference_on_model = ref_logprobs_model_data.sum(1)

    rejected_mask = ~chosen_mask
    chosen_model_logprobs = torch.where(chosen_mask, policy_on_model, policy_on_ref)
    chosen_ref_logprobs = torch.where(chosen_mask, reference_on_model, reference_on_ref)
    rejected_model_logprobs = torch.where(rejected_mask, policy_on_model, policy_on_ref)
    rejected_ref_logprobs = torch.where(rejected_mask, reference_on_model, reference_on_ref)

    # Logged value is the sum of the policy mean and the reference mean.
    record("logps/chosen", chosen_model_logprobs.mean() + chosen_ref_logprobs.mean())
    record("logps/rejected", rejected_model_logprobs.mean() + rejected_ref_logprobs.mean())

    # Implicit rewards: beta-scaled log-ratios
    chosen_rewards = (chosen_model_logprobs - chosen_ref_logprobs) * self.beta
    rejected_rewards = (rejected_model_logprobs - rejected_ref_logprobs) * self.beta
    record("rewards/chosen", chosen_rewards.mean())
    record("rewards/rejected", rejected_rewards.mean())

    # KL estimate averaged over the two data sources
    kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
    kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
    record("objective/kl", (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2)

    # Logged as "entropy": negated summed policy log-probs, averaged over both sources
    record("objective/entropy", ((-policy_on_model).mean() + (-policy_on_ref).mean()) / 2)

    # Reward margin and how often it is positive
    margin = chosen_rewards - rejected_rewards
    record("rewards/margins", margin.mean())
    record("rewards/accuracies", (margin > 0).float().mean())

    # Fraction of completions that contain an EOS token
    eos_token_id = self.processing_class.eos_token_id
    model_eos = (model_data["input_ids"][:, context_length:] == eos_token_id).any(dim=1)
    ref_eos = (ref_data["input_ids"][:, context_length:] == eos_token_id).any(dim=1)
    record("val/model_contain_eos_token", model_eos.float())
    record("val/ref_contain_eos_token", ref_eos.float())

    # Current hyper-parameters
    stats["alpha"].append(self.alpha)
    stats["beta"].append(self.beta)
|
| 696 |
+
|
| 697 |
+
def training_step(
    self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
) -> torch.Tensor:
    """Run one XPO training step: generate, rank, compute the loss and backpropagate.

    Args:
        model: The policy being trained; put into train mode here.
        inputs: Column-oriented batch; must contain a ``"prompt"`` column.
        num_items_in_batch: Accepted but not used in this method.

    Returns:
        The detached loss divided by ``self.args.gradient_accumulation_steps``.
    """
    model.train()

    # Apply chat template and tokenize the input
    batch_size = len(next(iter(inputs.values())))
    prompts = inputs["prompt"]  # keep the raw prompts; the judge path needs them later
    # Turn the dict-of-columns into a list of per-example dicts.
    inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
    inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
    inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
    inputs = self.data_collator(inputs)

    # need the prompt_ only
    inputs = self._prepare_inputs(inputs)
    # All prompts share one padded length; everything after it is completion.
    context_length = inputs["prompt_input_ids"].shape[1]
    prompts = {
        "input_ids": inputs["prompt_input_ids"],
        "attention_mask": inputs["prompt_attention_mask"],
        "raw": prompts,
    }
    del inputs

    # Sample completions from both the model and the reference model
    model_output, ref_output = self._generate_completions(prompts, model)

    # Process model completions
    model_data, ref_data = self._process_completions(model_output, ref_output, prompts)

    # Compute rewards
    if self.reward_model is not None:
        model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
        # Ties go to the policy completion.
        chosen_mask = model_scores >= ref_scores
    else:
        # No reward model: fall back to the pairwise judge.
        model_scores, ref_scores = None, None
        chosen_mask = self._compute_judge(model_data, ref_data, context_length)

    # Compute logprobs
    model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
        self._compute_logprobs(model, model_data, ref_data, context_length)
    )

    # Compute loss
    loss, dpo_losses, xpo_losses = self._compute_losses(
        model_logprobs_model_data,
        model_logprobs_ref_data,
        ref_logprobs_ref_data,
        ref_logprobs_model_data,
        chosen_mask,
    )

    # Log everything
    # Policy-side tensors are detached so logging keeps no autograd graph alive.
    self._log_statistics(
        model_data,
        ref_data,
        model_logprobs_model_data.detach(),
        model_logprobs_ref_data.detach(),
        ref_logprobs_ref_data,
        ref_logprobs_model_data,
        chosen_mask,
        dpo_losses.detach(),
        xpo_losses.detach(),
        context_length,
        model_scores,
        ref_scores,
    )

    # Optionally release cached accelerator memory every `torch_empty_cache_steps` steps.
    if (
        self.args.torch_empty_cache_steps is not None
        and self.state.global_step % self.args.torch_empty_cache_steps == 0
    ):
        empty_cache()

    kwargs = {}
    # For LOMO optimizers you need to explicitly use the learning rate
    if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
        kwargs["learning_rate"] = self._get_learning_rate()

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    # Backward pass: apex loss scaling when enabled, otherwise through accelerate.
    if self.use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        self.accelerator.backward(loss, **kwargs)

    return loss.detach() / self.args.gradient_accumulation_steps
|
| 785 |
+
|
| 786 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to ``README.md`` inside ``self.args.output_dir``. Nothing is
    done on processes other than the world-zero process.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The passed object is never modified.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalise into a fresh list so that appending below cannot mutate the caller's list.
    if not tags:
        tags = []
    elif isinstance(tags, str):
        tags = [tags]
    else:
        tags = list(tags)

    if hasattr(self.model.config, "unsloth_version") and "unsloth" not in tags:
        tags.append("unsloth")

    # The citation key names the XPO paper (Xie et al., 2024) described by the entry itself.
    citation = textwrap.dedent("""\
    @article{xie2024exploratory,
        title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
        author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
        year = 2024,
        eprint = {arXiv:2405.21046}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="XPO",
        trainer_citation=citation,
        paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
        paper_id="2405.21046",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 841 |
+
class UnslothXPOTrainer(_UnslothXPOTrainer):
    """
    
    Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
        reward_model (`transformers.PreTrainedModel`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`XPOConfig`):
            The XPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
    
    """
    def __init__(
        self,
        model = None,
        ref_model = None,
        reward_model = None,
        judge = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        peft_config = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        **kwargs
    ):
        # Reconcile precision flags, eval settings, padding side and the data collator
        # on `args` / the inputs, then delegate to the parent trainer's __init__.
        if args is None: args = UnslothXPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        # UNSLOTH_FORCE_FLOAT32=1 disables mixed precision altogether.
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # NOTE(review): `model` defaults to None but is dereferenced here without a guard,
        # although a later statement does check `model is not None` - confirm None is never passed.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Reject fp16/bf16 requests that contradict the model's own dtype.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was set: derive both from the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        # NOTE(review): this reads `eval_dataset` from `args`, not the `eval_dataset`
        # parameter of this method - confirm that is intended.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Only warns; nothing is changed for old transformers versions.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # An eval batch size of exactly 8 is lowered to the train batch size when that is smaller.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        # Keep the full-eval precision flags consistent with the training precision.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # Setting either metric hook sets UNSLOTH_RETURN_LOGITS=1.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            # Inherit max_seq_length from the model when args has the field but no value.
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # This signature has no `tokenizer` parameter, so the next check finds no such local;
        # `processing_class` is a parameter and therefore always present in locals().
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator type depending on whether the dataset has a 'labels' column.
            # NOTE(review): `train_dataset.column_names` is read without a None check - confirm.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator: override the column-handling fields that exist on args.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # The processing class has no `pad` but wraps a `.tokenizer`: build the collator on that.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []
        
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('xpo_trainer', other_metrics)
        
        super().__init__(
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            judge = judge,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
        # Remove any NEFTune forward hook left on the instance by the parent initialiser;
        # the noise alpha is instead stored on the input embedding module below.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
|
| 1009 |
+
|
| 1010 |
+
pass
|
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc
ADDED
|
Binary file (32.9 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc
ADDED
|
Binary file (91.7 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc
ADDED
|
Binary file (75.6 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc
ADDED
|
Binary file (45.5 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d636dd9bd05907f328159064584e1667333117b7120ddbed1c3c316bc279cc36
|
| 3 |
+
size 103583
|
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc
ADDED
|
Binary file (37.7 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc
ADDED
|
Binary file (78.5 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc
ADDED
|
Binary file (87.4 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc
ADDED
|
Binary file (47.3 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc
ADDED
|
Binary file (75.6 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc
ADDED
|
Binary file (67.2 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc
ADDED
|
Binary file (62.7 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc
ADDED
|
Binary file (36.5 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc
ADDED
|
Binary file (54.2 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc
ADDED
|
Binary file (38.9 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc
ADDED
|
Binary file (48.1 kB). View file
|
|
|
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc
ADDED
|
Binary file (49.9 kB). View file
|
|
|
upload_utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import HfApi
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def upload_to_huggingface(model, tokenizer, repo_name, token):
    """
    Upload a fine-tuned model and tokenizer to Hugging Face.

    The model and tokenizer are saved into a unique temporary directory that is
    always removed afterwards, whether the upload succeeds or fails.

    Args:
        model: The fine-tuned model to upload.
        tokenizer: The tokenizer associated with the model.
        repo_name (str): The name of the repository to create/update on Hugging Face.
        token (str): Hugging Face API token.

    Returns:
        str: A message indicating the success or failure of the upload.
    """
    # Function-scope import so the block does not depend on the file's import header.
    import tempfile

    try:
        # A unique directory avoids collisions between concurrent uploads, and the
        # context manager removes it (including sub-directories) even when a step raises.
        with tempfile.TemporaryDirectory(prefix="temp_model_") as temp_dir:
            model.save_pretrained(temp_dir)
            tokenizer.save_pretrained(temp_dir)

            # Initialize the Hugging Face API
            api = HfApi()

            # Create or update the repository
            api.create_repo(repo_id=repo_name, token=token, exist_ok=True)

            # Upload the model and tokenizer files
            api.upload_folder(
                folder_path=temp_dir,
                repo_id=repo_name,
                token=token,
            )

        return f"Model successfully uploaded to https://huggingface.co/{repo_name}"
    except Exception as e:
        # Callers expect a status string rather than an exception.
        return f"Error uploading model: {str(e)}"
|
| 44 |
+
|
| 45 |
+
def upload_gguf_to_huggingface(gguf_file_path, repo_name, token):
    """
    Push a single GGUF model file to a Hugging Face repository.

    Args:
        gguf_file_path (str): Location of the GGUF file on disk.
        repo_name (str): Repository to create (if needed) and upload into.
        token (str): Hugging Face API token.

    Returns:
        str: A human-readable status message; errors are reported in the
        returned string instead of being raised.
    """
    try:
        file_is_present = os.path.exists(gguf_file_path)
        if file_is_present:
            hub = HfApi()

            # Make sure the target repository exists before uploading.
            hub.create_repo(repo_id=repo_name, token=token, exist_ok=True)

            # The file keeps its own name at the root of the repository.
            name_in_repo = os.path.basename(gguf_file_path)
            hub.upload_file(
                path_or_fileobj=gguf_file_path,
                path_in_repo=name_in_repo,
                repo_id=repo_name,
                token=token,
            )

            return f"GGUF model successfully uploaded to https://huggingface.co/{repo_name}"

        return f"Error: GGUF file not found at {gguf_file_path}"
    except Exception as e:
        return f"Error uploading GGUF model: {str(e)}"
|