Antonio Toro Jaén committed on
Commit 5d94b8c · 1 Parent(s): 0828923
Files changed (35)
  1. app.py +13 -6
  2. unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
  3. unsloth_compiled_cache/UnslothBCOTrainer.py +1818 -0
  4. unsloth_compiled_cache/UnslothCPOTrainer.py +1551 -0
  5. unsloth_compiled_cache/UnslothDDPOTrainer.py +872 -0
  6. unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
  7. unsloth_compiled_cache/UnslothGKDTrainer.py +857 -0
  8. unsloth_compiled_cache/UnslothGRPOTrainer.py +1432 -0
  9. unsloth_compiled_cache/UnslothKTOTrainer.py +1834 -0
  10. unsloth_compiled_cache/UnslothNashMDTrainer.py +949 -0
  11. unsloth_compiled_cache/UnslothORPOTrainer.py +1537 -0
  12. unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1263 -0
  13. unsloth_compiled_cache/UnslothPPOTrainer.py +1253 -0
  14. unsloth_compiled_cache/UnslothPRMTrainer.py +794 -0
  15. unsloth_compiled_cache/UnslothRLOOTrainer.py +1127 -0
  16. unsloth_compiled_cache/UnslothRewardTrainer.py +813 -0
  17. unsloth_compiled_cache/UnslothSFTTrainer.py +1025 -0
  18. unsloth_compiled_cache/UnslothXPOTrainer.py +1004 -0
  19. unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-312.pyc +0 -0
  20. unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  21. unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  22. unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-312.pyc +0 -0
  23. unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  24. unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  25. unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  26. unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  27. unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  28. unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  29. unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  30. unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  31. unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  32. unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  33. unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  34. unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  35. unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
app.py CHANGED
@@ -4,14 +4,21 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import re
 import os
 import csv
+from huggingface_hub import login
 
-from unsloth import FastLanguageModel
-model, tokenizer = FastLanguageModel.from_pretrained(
-    model_name = "atorojaen/DeepSeekMisongynyLyrics", # Base model
-    max_seq_length = 2048,
-    dtype = torch.float16,
-    load_in_4bit = True,
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+model_name = "atorojaen/DeepSeekMisongynyLyrics"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",         # or "cuda" to force it onto the GPU
+    torch_dtype=torch.float16  # if your model supports it
 )
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.eval()  # <- yes, this can be used
 
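The replacement loading path above yields a standard `transformers` causal LM. One caveat the inline comment glosses over: `model.eval()` only disables dropout-style layers, so inference should still run under `torch.no_grad()`. A minimal sketch of exercising the loaded `model`/`tokenizer` pair (the prompt text and generation parameters are illustrative assumptions, not part of this commit):

import torch

# Assumes `model` and `tokenizer` are the objects loaded in the hunk above.
prompt = "Classify the following lyrics as misogynistic or not:\n..."  # hypothetical prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():  # model.eval() alone does not disable autograd
    output_ids = model.generate(**inputs, max_new_tokens=32)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))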
unsloth_compiled_cache/UnslothAlignPropTrainer.py ADDED
@@ -0,0 +1,637 @@
+"""
+2025.5.8
+2025.5.7
+4.51.3
+0.15.2
+__UNSLOTH_VERSIONING__
+"""
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
+
+
+import os
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+torch_compile_options = {
+    "epilogue_fusion"   : True,
+    "max_autotune"      : False,
+    "shape_padding"     : True,
+    "trace.enabled"     : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def selective_log_softmax(logits, index):
+    logits = logits.to(torch.float32)
+    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+    # loop to reduce peak mem consumption
+    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+    logsumexp_values = torch.logsumexp(logits, dim = -1)
+    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+    return per_token_logps
+@dataclass
+class UnslothAlignPropConfig(AlignPropConfig):
+    """
+
+    Configuration class for the [`AlignPropTrainer`].
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
+            Name of this experiment (defaults to the file name without the extension).
+        run_name (`str`, *optional*, defaults to `""`):
+            Name of this run.
+        seed (`int`, *optional*, defaults to `0`):
+            Random seed for reproducibility.
+        log_with (`str` or `None`, *optional*, defaults to `None`):
+            Log with either `"wandb"` or `"tensorboard"`. Check
+            [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
+        log_image_freq (`int`, *optional*, defaults to `1`):
+            Frequency for logging images.
+        tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+            Keyword arguments for the tracker (e.g., `wandb_project`).
+        accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+            Keyword arguments for the accelerator.
+        project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+            Keyword arguments for the accelerator project config (e.g., `logging_dir`).
+        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
+            Name of project to use for tracking.
+        logdir (`str`, *optional*, defaults to `"logs"`):
+            Top-level logging directory for checkpoint saving.
+        num_epochs (`int`, *optional*, defaults to `100`):
+            Number of epochs to train.
+        save_freq (`int`, *optional*, defaults to `1`):
+            Number of epochs between saving model checkpoints.
+        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
+            Number of checkpoints to keep before overwriting old ones.
+        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
+            Mixed precision training.
+        allow_tf32 (`bool`, *optional*, defaults to `True`):
+            Allow `tf32` on Ampere GPUs.
+        resume_from (`str`, *optional*, defaults to `""`):
+            Path to resume training from a checkpoint.
+        sample_num_steps (`int`, *optional*, defaults to `50`):
+            Number of sampler inference steps.
+        sample_eta (`float`, *optional*, defaults to `1.0`):
+            Eta parameter for the DDIM sampler.
+        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
+            Classifier-free guidance weight.
+        train_batch_size (`int`, *optional*, defaults to `1`):
+            Batch size for training.
+        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
+            Whether to use the 8bit Adam optimizer from `bitsandbytes`.
+        train_learning_rate (`float`, *optional*, defaults to `1e-3`):
+            Learning rate.
+        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
+            Beta1 for Adam optimizer.
+        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
+            Beta2 for Adam optimizer.
+        train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
+            Weight decay for Adam optimizer.
+        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
+            Epsilon value for Adam optimizer.
+        train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
+            Number of gradient accumulation steps.
+        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
+            Maximum gradient norm for gradient clipping.
+        negative_prompts (`str` or `None`, *optional*, defaults to `None`):
+            Comma-separated list of prompts to use as negative examples.
+        truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
+            If `True`, randomized truncation to different diffusion timesteps is used.
+        truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
+            Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
+        truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
+            Range of diffusion timesteps for randomized truncated backpropagation.
+        push_to_hub (`bool`, *optional*, defaults to `False`):
+            Whether to push the final model to the Hub.
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks : Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    def __init__(
+        self,
+        exp_name = 'app',
+        run_name = '',
+        seed = 3407,
+        log_with = None,
+        log_image_freq = 1,
+        tracker_project_name = 'trl',
+        logdir = 'logs',
+        num_epochs = 100,
+        save_freq = 1,
+        num_checkpoint_limit = 5,
+        mixed_precision = 'fp16',
+        allow_tf32 = True,
+        resume_from = '',
+        sample_num_steps = 50,
+        sample_eta = 1.0,
+        sample_guidance_scale = 5.0,
+        train_batch_size = 1,
+        train_use_8bit_adam = False,
+        train_learning_rate = 5e-05,
+        train_adam_beta1 = 0.9,
+        train_adam_beta2 = 0.999,
+        train_adam_weight_decay = 0.01,
+        train_adam_epsilon = 1e-08,
+        train_gradient_accumulation_steps = 2,
+        train_max_grad_norm = 1.0,
+        negative_prompts = None,
+        truncated_backprop_rand = True,
+        truncated_backprop_timestep = 49,
+        push_to_hub = False,
+        vllm_sampling_params = None,
+        unsloth_num_chunks = -1,
+        **kwargs,
+    ):
+
+        super().__init__(
+            exp_name = exp_name,
+            run_name = run_name,
+            seed = seed,
+            log_with = log_with,
+            log_image_freq = log_image_freq,
+            tracker_project_name = tracker_project_name,
+            logdir = logdir,
+            num_epochs = num_epochs,
+            save_freq = save_freq,
+            num_checkpoint_limit = num_checkpoint_limit,
+            mixed_precision = mixed_precision,
+            allow_tf32 = allow_tf32,
+            resume_from = resume_from,
+            sample_num_steps = sample_num_steps,
+            sample_eta = sample_eta,
+            sample_guidance_scale = sample_guidance_scale,
+            train_batch_size = train_batch_size,
+            train_use_8bit_adam = train_use_8bit_adam,
+            train_learning_rate = train_learning_rate,
+            train_adam_beta1 = train_adam_beta1,
+            train_adam_beta2 = train_adam_beta2,
+            train_adam_weight_decay = train_adam_weight_decay,
+            train_adam_epsilon = train_adam_epsilon,
+            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
+            train_max_grad_norm = train_max_grad_norm,
+            negative_prompts = negative_prompts,
+            truncated_backprop_rand = truncated_backprop_rand,
+            truncated_backprop_timestep = truncated_backprop_timestep,
+            push_to_hub = push_to_hub,**kwargs)
+        self.vllm_sampling_params = vllm_sampling_params
+        self.unsloth_num_chunks = unsloth_num_chunks
+        pass
+
+class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
+    """"""
+
+    _tag_names = ["trl", "alignprop"]
+
+    def __init__(
+        self,
+        config: AlignPropConfig,
+        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
+        prompt_function: Callable[[], tuple[str, Any]],
+        sd_pipeline: DDPOStableDiffusionPipeline,
+        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
+    ):
+        if image_samples_hook is None:
+            warn("No image_samples_hook provided; no images will be logged")
+
+        self.prompt_fn = prompt_function
+        self.reward_fn = reward_function
+        self.config = config
+        self.image_samples_callback = image_samples_hook
+
+        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
+
+        if self.config.resume_from:
+            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
+            if "checkpoint_" not in os.path.basename(self.config.resume_from):
+                # get the most recent checkpoint in this directory
+                checkpoints = list(
+                    filter(
+                        lambda x: "checkpoint_" in x,
+                        os.listdir(self.config.resume_from),
+                    )
+                )
+                if len(checkpoints) == 0:
+                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
+                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
+                self.config.resume_from = os.path.join(
+                    self.config.resume_from,
+                    f"checkpoint_{checkpoint_numbers[-1]}",
+                )
+
+                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
+
+        self.accelerator = Accelerator(
+            log_with=self.config.log_with,
+            mixed_precision=self.config.mixed_precision,
+            project_config=accelerator_project_config,
+            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
+            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
+            # the total number of optimizer steps to accumulate across.
+            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
+            **self.config.accelerator_kwargs,
+        )
+
+        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
+
+        if self.accelerator.is_main_process:
+            self.accelerator.init_trackers(
+                self.config.tracker_project_name,
+                config=dict(alignprop_trainer_config=config.to_dict())
+                if not is_using_tensorboard
+                else config.to_dict(),
+                init_kwargs=self.config.tracker_kwargs,
+            )
+
+        logger.info(f"\n{config}")
+
+        set_seed(self.config.seed, device_specific=True)
+
+        self.sd_pipeline = sd_pipeline
+
+        self.sd_pipeline.set_progress_bar_config(
+            position=1,
+            disable=not self.accelerator.is_local_main_process,
+            leave=False,
+            desc="Timestep",
+            dynamic_ncols=True,
+        )
+
+        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+        # as these weights are only used for inference, keeping weights in full precision is not required.
+        if self.accelerator.mixed_precision == "fp16":
+            inference_dtype = torch.float16
+        elif self.accelerator.mixed_precision == "bf16":
+            inference_dtype = torch.bfloat16
+        else:
+            inference_dtype = torch.float32
+
+        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
+        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
+        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
+
+        trainable_layers = self.sd_pipeline.get_trainable_layers()
+
+        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
+        self.accelerator.register_load_state_pre_hook(self._load_model_hook)
+
+        # Enable TF32 for faster training on Ampere GPUs,
+        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+        if self.config.allow_tf32:
+            torch.backends.cuda.matmul.allow_tf32 = True
+
+        self.optimizer = self._setup_optimizer(
+            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
+        )
+
+        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
+            self.sd_pipeline.tokenizer(
+                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=self.sd_pipeline.tokenizer.model_max_length,
+            ).input_ids.to(self.accelerator.device)
+        )[0]
+
+        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+        # more memory
+        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
+
+        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
+            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
+        else:
+            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+
+        if config.resume_from:
+            logger.info(f"Resuming from {config.resume_from}")
+            self.accelerator.load_state(config.resume_from)
+            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
+        else:
+            self.first_epoch = 0
+
+    def compute_rewards(self, prompt_image_pairs):
+        reward, reward_metadata = self.reward_fn(
+            prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
+        )
+        return reward
+
+    def step(self, epoch: int, global_step: int):
+        """
+        Perform a single step of training.
+
+        Args:
+            epoch (int): The current epoch.
+            global_step (int): The current global step.
+
+        Side Effects:
+            - Model weights are updated
+            - Logs the statistics to the accelerator trackers.
+            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
+
+        Returns:
+            global_step (int): The updated global step.
+        """
+        info = defaultdict(list)
+
+        self.sd_pipeline.unet.train()
+
+        for _ in range(self.config.train_gradient_accumulation_steps):
+            with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
+                prompt_image_pairs = self._generate_samples(
+                    batch_size=self.config.train_batch_size,
+                )
+
+                rewards = self.compute_rewards(prompt_image_pairs)
+
+                prompt_image_pairs["rewards"] = rewards
+
+                rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
+
+                loss = self.calculate_loss(rewards)
+
+                self.accelerator.backward(loss)
+
+                if self.accelerator.sync_gradients:
+                    self.accelerator.clip_grad_norm_(
+                        self.trainable_layers.parameters()
+                        if not isinstance(self.trainable_layers, list)
+                        else self.trainable_layers,
+                        self.config.train_max_grad_norm,
+                    )
+
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+            info["reward_mean"].append(rewards_vis.mean())
+            info["reward_std"].append(rewards_vis.std())
+            info["loss"].append(loss.item())
+
+        # Checks if the accelerator has performed an optimization step behind the scenes
+        if self.accelerator.sync_gradients:
+            # log training-related stuff
+            info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
+            info = self.accelerator.reduce(info, reduction="mean")
+            info.update({"epoch": epoch})
+            self.accelerator.log(info, step=global_step)
+            global_step += 1
+            info = defaultdict(list)
+        else:
+            raise ValueError(
+                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
+            )
+        # Logs generated images
+        if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
+            self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
+
+        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
+            self.accelerator.save_state()
+
+        return global_step
+
+    def calculate_loss(self, rewards):
+        """
+        Calculate the loss for a batch of an unpacked sample
+
+        Args:
+            rewards (torch.Tensor):
+                Differentiable reward scalars for each generated image, shape: [batch_size]
+
+        Returns:
+            loss (torch.Tensor)
+            (all of these are of shape (1,))
+        """
+        # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
+        loss = 10.0 - (rewards).mean()
+        return loss
+
+    def loss(
+        self,
+        advantages: torch.Tensor,
+        clip_range: float,
+        ratio: torch.Tensor,
+    ):
+        unclipped_loss = -advantages * ratio
+        clipped_loss = -advantages * torch.clamp(
+            ratio,
+            1.0 - clip_range,
+            1.0 + clip_range,
+        )
+        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
+
+    def _setup_optimizer(self, trainable_layers_parameters):
+        if self.config.train_use_8bit_adam:
+            import bitsandbytes
+
+            optimizer_cls = bitsandbytes.optim.AdamW8bit
+        else:
+            optimizer_cls = torch.optim.AdamW
+
+        return optimizer_cls(
+            trainable_layers_parameters,
+            lr=self.config.train_learning_rate,
+            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
+            weight_decay=self.config.train_adam_weight_decay,
+            eps=self.config.train_adam_epsilon,
+        )
+
+    def _save_model_hook(self, models, weights, output_dir):
+        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
+        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
+
+    def _load_model_hook(self, models, input_dir):
+        self.sd_pipeline.load_checkpoint(models, input_dir)
+        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
+
+    def _generate_samples(self, batch_size, with_grad=True, prompts=None):
+        """
+        Generate samples from the model
+
+        Args:
+            batch_size (int): Batch size to use for sampling
+            with_grad (bool): Whether the generated RGBs should have gradients attached to it.
+
+        Returns:
+            prompt_image_pairs (dict[Any])
+        """
+        prompt_image_pairs = {}
+
+        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
+
+        if prompts is None:
+            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
+        else:
+            prompt_metadata = [{} for _ in range(batch_size)]
+
+        prompt_ids = self.sd_pipeline.tokenizer(
+            prompts,
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=self.sd_pipeline.tokenizer.model_max_length,
+        ).input_ids.to(self.accelerator.device)
+
+        prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
+
+        if with_grad:
+            sd_output = self.sd_pipeline.rgb_with_grad(
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=sample_neg_prompt_embeds,
+                num_inference_steps=self.config.sample_num_steps,
+                guidance_scale=self.config.sample_guidance_scale,
+                eta=self.config.sample_eta,
+                truncated_backprop_rand=self.config.truncated_backprop_rand,
+                truncated_backprop_timestep=self.config.truncated_backprop_timestep,
+                truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
+                output_type="pt",
+            )
+        else:
+            sd_output = self.sd_pipeline(
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=sample_neg_prompt_embeds,
+                num_inference_steps=self.config.sample_num_steps,
+                guidance_scale=self.config.sample_guidance_scale,
+                eta=self.config.sample_eta,
+                output_type="pt",
+            )
+
+        images = sd_output.images
+
+        prompt_image_pairs["images"] = images
+        prompt_image_pairs["prompts"] = prompts
+        prompt_image_pairs["prompt_metadata"] = prompt_metadata
+
+        return prompt_image_pairs
+
+    def train(self, epochs: Optional[int] = None):
+        """
+        Train the model for a given number of epochs
+        """
+        global_step = 0
+        if epochs is None:
+            epochs = self.config.num_epochs
+        for epoch in range(self.first_epoch, epochs):
+            global_step = self.step(epoch, global_step)
+
+    def _save_pretrained(self, save_directory):
+        self.sd_pipeline.save_pretrained(save_directory)
+        self.create_model_card()
+
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        citation = textwrap.dedent("""\
+        @article{prabhudesai2024aligning,
+            title        = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
+            author       = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
+            year         = 2024,
+            eprint       = {arXiv:2310.03739}
+        }""")
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="AlignProp",
+            trainer_citation=citation,
+            paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
+            paper_id="2310.03739",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
+    """
+
+    The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
+    Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
+    As of now only Stable Diffusion based pipelines are supported
+
+    Attributes:
+        config (`AlignPropConfig`):
+            Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
+        reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
+            Reward function to be used
+        prompt_function (`Callable[[], tuple[str, Any]]`):
+            Function to generate prompts to guide model
+        sd_pipeline (`DDPOStableDiffusionPipeline`):
+            Stable Diffusion pipeline to be used for training.
+        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
+            Hook to be called to log images
+
+    """
+    def __init__(
+        self,
+        config,
+        reward_function,
+        prompt_function,
+        sd_pipeline,
+        image_samples_hook = None,
+        **kwargs
+    ):
+        if config is None: config = UnslothAlignPropConfig()
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('alignprop_trainer', other_metrics)
+
+        super().__init__(
+            config = config,
+            reward_function = reward_function,
+            prompt_function = prompt_function,
+            sd_pipeline = sd_pipeline,
+            image_samples_hook = image_samples_hook,**kwargs)
+
+pass
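A note on the `selective_log_softmax` helper that heads each of these generated trainer files: it exploits the identity log_softmax(x)_i = x_i - logsumexp(x), gathering only the target logits so the full log-softmax matrix is never materialized. A small eager-mode sanity check against `F.log_softmax` (this sketch drops the `@torch.compile` wrapper and is not part of the commit):

import torch
import torch.nn.functional as F

def selective_log_softmax_eager(logits, index):
    # Same arithmetic as the compiled helper above, minus @torch.compile.
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

logits = torch.randn(2, 5, 32)         # (batch, sequence, vocab)
index = torch.randint(0, 32, (2, 5))   # target token ids per position

fast = selective_log_softmax_eager(logits, index)
slow = F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(fast, slow, atol=1e-6)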
unsloth_compiled_cache/UnslothBCOTrainer.py ADDED
@@ -0,0 +1,1818 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothBCOConfig(BCOConfig):
44
+ """
45
+
46
+ Configuration class for the [`BCOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
54
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
55
+ to use the default data collator.
56
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
57
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
58
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
59
+ Maximum length of the completion. This argument is required if you want to use the default data collator
60
+ and your model is an encoder-decoder.
61
+ beta (`float`, *optional*, defaults to `0.1`):
62
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
63
+ reference model.
64
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
65
+ Label pad token id. This argument is required if you want to use the default data collator.
66
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
67
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
68
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
69
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
70
+ This argument is required if you want to use the default data collator.
71
+ disable_dropout (`bool`, *optional*, defaults to `True`):
72
+ Whether to disable dropout in the model and reference model.
73
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
74
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
75
+ evaluation.
76
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
77
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
78
+ you need to specify if the model returned by the callable is an encoder-decoder model.
79
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
80
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
81
+ useful when training without the reference model to reduce the total GPU memory needed.
82
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
83
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
84
+ string.
85
+ ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
86
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
87
+ from a string.
88
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
+ Number of processes to use for processing the dataset.
90
+ prompt_sample_size (`int`, *optional*, defaults to `1024`):
91
+ Number of prompts that are fed to density ratio classifier.
92
+ min_density_ratio (`float`, *optional*, defaults to `0.5`):
93
+ Minimum value of the density ratio. The estimated density ratio is clamped to this value.
94
+ max_density_ratio (`float`, *optional*, defaults to `10.0`):
95
+ Maximum value of the density ratio. The estimated density ratio is clamped to this value.
96
+
97
+ """
98
+ vllm_sampling_params: Optional[Any] = field(
99
+ default = None,
100
+ metadata = {'help': 'vLLM SamplingParams'},
101
+ )
102
+ unsloth_num_chunks : Optional[int] = field(
103
+ default = -1,
104
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
105
+ )
106
+ def __init__(
107
+ self,
108
+ output_dir = None,
109
+ overwrite_output_dir = None,
110
+ do_train = False,
111
+ do_eval = False,
112
+ do_predict = False,
113
+ eval_strategy = 'no',
114
+ prediction_loss_only = False,
115
+ per_device_train_batch_size = 4,
116
+ per_device_eval_batch_size = 4,
117
+ per_gpu_train_batch_size = None,
118
+ per_gpu_eval_batch_size = None,
119
+ gradient_accumulation_steps = 2,
120
+ eval_accumulation_steps = 2,
121
+ eval_delay = 0,
122
+ torch_empty_cache_steps = 250,
123
+ learning_rate = 5e-05,
124
+ weight_decay = 0.01,
125
+ adam_beta1 = 0.9,
126
+ adam_beta2 = 0.999,
127
+ adam_epsilon = 1e-08,
128
+ max_grad_norm = 1.0,
129
+ num_train_epochs = 3.0,
130
+ max_steps = -1,
131
+ lr_scheduler_type = 'linear',
132
+ warmup_ratio = 0.1,
133
+ warmup_steps = 0,
134
+ log_level = 'passive',
135
+ log_level_replica = 'warning',
136
+ log_on_each_node = True,
137
+ logging_dir = None,
138
+ logging_strategy = 'steps',
139
+ logging_first_step = False,
140
+ logging_steps = 1,
141
+ logging_nan_inf_filter = False,
142
+ save_strategy = 'steps',
143
+ save_steps = 500,
144
+ save_total_limit = None,
145
+ save_safetensors = True,
146
+ save_on_each_node = False,
147
+ save_only_model = False,
148
+ restore_callback_states_from_checkpoint = False,
149
+ no_cuda = False,
150
+ use_cpu = False,
151
+ use_mps_device = False,
152
+ seed = 3407,
153
+ data_seed = 3407,
154
+ jit_mode_eval = False,
155
+ use_ipex = False,
156
+ bf16 = False,
157
+ fp16 = False,
158
+ fp16_opt_level = 'O1',
159
+ half_precision_backend = 'auto',
160
+ bf16_full_eval = False,
161
+ fp16_full_eval = False,
162
+ tf32 = None,
163
+ local_rank = -1,
164
+ ddp_backend = None,
165
+ tpu_num_cores = None,
166
+ tpu_metrics_debug = False,
167
+ debug = '',
168
+ dataloader_drop_last = False,
169
+ eval_steps = None,
170
+ dataloader_num_workers = 0,
171
+ dataloader_prefetch_factor = None,
172
+ past_index = -1,
173
+ run_name = None,
174
+ disable_tqdm = None,
175
+ remove_unused_columns = True,
176
+ label_names = None,
177
+ load_best_model_at_end = False,
178
+ metric_for_best_model = None,
179
+ greater_is_better = None,
180
+ ignore_data_skip = False,
181
+ fsdp = '',
182
+ fsdp_min_num_params = 0,
183
+ fsdp_config = None,
184
+ tp_size = 0,
185
+ fsdp_transformer_layer_cls_to_wrap = None,
186
+ accelerator_config = None,
187
+ deepspeed = None,
188
+ label_smoothing_factor = 0.0,
189
+ optim = 'adamw_8bit',
190
+ optim_args = None,
191
+ adafactor = False,
192
+ group_by_length = False,
193
+ length_column_name = 'length',
194
+ report_to = None,
195
+ ddp_find_unused_parameters = None,
196
+ ddp_bucket_cap_mb = None,
197
+ ddp_broadcast_buffers = None,
198
+ dataloader_pin_memory = True,
199
+ dataloader_persistent_workers = False,
200
+ skip_memory_metrics = True,
201
+ use_legacy_prediction_loop = False,
202
+ push_to_hub = False,
203
+ resume_from_checkpoint = None,
204
+ hub_model_id = None,
205
+ hub_strategy = 'every_save',
206
+ hub_token = None,
207
+ hub_private_repo = None,
208
+ hub_always_push = False,
209
+ gradient_checkpointing = False,
210
+ gradient_checkpointing_kwargs = None,
211
+ include_inputs_for_metrics = False,
212
+ eval_do_concat_batches = True,
213
+ fp16_backend = 'auto',
214
+ push_to_hub_model_id = None,
215
+ push_to_hub_organization = None,
216
+ push_to_hub_token = None,
217
+ mp_parameters = '',
218
+ auto_find_batch_size = False,
219
+ full_determinism = False,
220
+ torchdynamo = None,
221
+ ray_scope = 'last',
222
+ ddp_timeout = 1800,
223
+ torch_compile = False,
224
+ torch_compile_backend = None,
225
+ torch_compile_mode = None,
226
+ include_tokens_per_second = False,
227
+ include_num_input_tokens_seen = False,
228
+ neftune_noise_alpha = None,
229
+ optim_target_modules = None,
230
+ batch_eval_metrics = False,
231
+ eval_on_start = False,
232
+ use_liger_kernel = False,
233
+ eval_use_gather_object = False,
234
+ average_tokens_across_devices = False,
235
+ max_length = 1024,
236
+ max_prompt_length = 512,
237
+ max_completion_length = None,
238
+ beta = 0.1,
239
+ label_pad_token_id = -100,
240
+ padding_value = None,
241
+ truncation_mode = 'keep_end',
242
+ disable_dropout = True,
243
+ generate_during_eval = False,
244
+ is_encoder_decoder = None,
245
+ precompute_ref_log_probs = False,
246
+ model_init_kwargs = None,
247
+ ref_model_init_kwargs = None,
248
+ dataset_num_proc = None,
249
+ prompt_sample_size = 1024,
250
+ min_density_ratio = 0.5,
251
+ max_density_ratio = 10.0,
252
+ vllm_sampling_params = None,
253
+ unsloth_num_chunks = -1,
254
+ **kwargs,
255
+ ):
256
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
257
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
258
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
259
+ output_dir = 'unsloth_training_checkpoints'
260
+ save_strategy = 'no'
261
+ if dataset_num_proc is None:
262
+ from multiprocessing import cpu_count
263
+ dataset_num_proc = cpu_count()
264
+
265
+ super().__init__(
266
+ output_dir = output_dir,
267
+ overwrite_output_dir = overwrite_output_dir,
268
+ do_train = do_train,
269
+ do_eval = do_eval,
270
+ do_predict = do_predict,
271
+ eval_strategy = eval_strategy,
272
+ prediction_loss_only = prediction_loss_only,
273
+ per_device_train_batch_size = per_device_train_batch_size,
274
+ per_device_eval_batch_size = per_device_eval_batch_size,
275
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
276
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
277
+ gradient_accumulation_steps = gradient_accumulation_steps,
278
+ eval_accumulation_steps = eval_accumulation_steps,
279
+ eval_delay = eval_delay,
280
+ torch_empty_cache_steps = torch_empty_cache_steps,
281
+ learning_rate = learning_rate,
282
+ weight_decay = weight_decay,
283
+ adam_beta1 = adam_beta1,
284
+ adam_beta2 = adam_beta2,
285
+ adam_epsilon = adam_epsilon,
286
+ max_grad_norm = max_grad_norm,
287
+ num_train_epochs = num_train_epochs,
288
+ max_steps = max_steps,
289
+ lr_scheduler_type = lr_scheduler_type,
290
+ warmup_ratio = warmup_ratio,
291
+ warmup_steps = warmup_steps,
292
+ log_level = log_level,
293
+ log_level_replica = log_level_replica,
294
+ log_on_each_node = log_on_each_node,
295
+ logging_dir = logging_dir,
296
+ logging_strategy = logging_strategy,
297
+ logging_first_step = logging_first_step,
298
+ logging_steps = logging_steps,
299
+ logging_nan_inf_filter = logging_nan_inf_filter,
300
+ save_strategy = save_strategy,
301
+ save_steps = save_steps,
302
+ save_total_limit = save_total_limit,
303
+ save_safetensors = save_safetensors,
304
+ save_on_each_node = save_on_each_node,
305
+ save_only_model = save_only_model,
306
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
307
+ no_cuda = no_cuda,
308
+ use_cpu = use_cpu,
309
+ use_mps_device = use_mps_device,
310
+ seed = seed,
311
+ data_seed = data_seed,
312
+ jit_mode_eval = jit_mode_eval,
313
+ use_ipex = use_ipex,
314
+ bf16 = bf16,
315
+ fp16 = fp16,
316
+ fp16_opt_level = fp16_opt_level,
317
+ half_precision_backend = half_precision_backend,
318
+ bf16_full_eval = bf16_full_eval,
319
+ fp16_full_eval = fp16_full_eval,
320
+ tf32 = tf32,
321
+ local_rank = local_rank,
322
+ ddp_backend = ddp_backend,
323
+ tpu_num_cores = tpu_num_cores,
324
+ tpu_metrics_debug = tpu_metrics_debug,
325
+ debug = debug,
326
+ dataloader_drop_last = dataloader_drop_last,
327
+ eval_steps = eval_steps,
328
+ dataloader_num_workers = dataloader_num_workers,
329
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
330
+ past_index = past_index,
331
+ run_name = run_name,
332
+ disable_tqdm = disable_tqdm,
333
+ remove_unused_columns = remove_unused_columns,
334
+ label_names = label_names,
335
+ load_best_model_at_end = load_best_model_at_end,
336
+ metric_for_best_model = metric_for_best_model,
337
+ greater_is_better = greater_is_better,
338
+ ignore_data_skip = ignore_data_skip,
339
+ fsdp = fsdp,
340
+ fsdp_min_num_params = fsdp_min_num_params,
341
+ fsdp_config = fsdp_config,
342
+ tp_size = tp_size,
343
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
344
+ accelerator_config = accelerator_config,
345
+ deepspeed = deepspeed,
346
+ label_smoothing_factor = label_smoothing_factor,
347
+ optim = optim,
348
+ optim_args = optim_args,
349
+ adafactor = adafactor,
350
+ group_by_length = group_by_length,
351
+ length_column_name = length_column_name,
352
+ report_to = report_to,
353
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
354
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
355
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
356
+ dataloader_pin_memory = dataloader_pin_memory,
357
+ dataloader_persistent_workers = dataloader_persistent_workers,
358
+ skip_memory_metrics = skip_memory_metrics,
359
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
360
+ push_to_hub = push_to_hub,
361
+ resume_from_checkpoint = resume_from_checkpoint,
362
+ hub_model_id = hub_model_id,
363
+ hub_strategy = hub_strategy,
364
+ hub_token = hub_token,
365
+ hub_private_repo = hub_private_repo,
366
+ hub_always_push = hub_always_push,
367
+ gradient_checkpointing = gradient_checkpointing,
368
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
369
+ include_inputs_for_metrics = include_inputs_for_metrics,
370
+ eval_do_concat_batches = eval_do_concat_batches,
371
+ fp16_backend = fp16_backend,
372
+ push_to_hub_model_id = push_to_hub_model_id,
373
+ push_to_hub_organization = push_to_hub_organization,
374
+ push_to_hub_token = push_to_hub_token,
375
+ mp_parameters = mp_parameters,
376
+ auto_find_batch_size = auto_find_batch_size,
377
+ full_determinism = full_determinism,
378
+ torchdynamo = torchdynamo,
379
+ ray_scope = ray_scope,
380
+ ddp_timeout = ddp_timeout,
381
+ torch_compile = torch_compile,
382
+ torch_compile_backend = torch_compile_backend,
383
+ torch_compile_mode = torch_compile_mode,
384
+ include_tokens_per_second = include_tokens_per_second,
385
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
386
+ neftune_noise_alpha = neftune_noise_alpha,
387
+ optim_target_modules = optim_target_modules,
388
+ batch_eval_metrics = batch_eval_metrics,
389
+ eval_on_start = eval_on_start,
390
+ use_liger_kernel = use_liger_kernel,
391
+ eval_use_gather_object = eval_use_gather_object,
392
+ average_tokens_across_devices = average_tokens_across_devices,
393
+ max_length = max_length,
394
+ max_prompt_length = max_prompt_length,
395
+ max_completion_length = max_completion_length,
396
+ beta = beta,
397
+ label_pad_token_id = label_pad_token_id,
398
+ padding_value = padding_value,
399
+ truncation_mode = truncation_mode,
400
+ disable_dropout = disable_dropout,
401
+ generate_during_eval = generate_during_eval,
402
+ is_encoder_decoder = is_encoder_decoder,
403
+ precompute_ref_log_probs = precompute_ref_log_probs,
404
+ model_init_kwargs = model_init_kwargs,
405
+ ref_model_init_kwargs = ref_model_init_kwargs,
406
+ dataset_num_proc = dataset_num_proc,
407
+ prompt_sample_size = prompt_sample_size,
408
+ min_density_ratio = min_density_ratio,
409
+ max_density_ratio = max_density_ratio,**kwargs)
410
+ self.vllm_sampling_params = vllm_sampling_params
411
+ self.unsloth_num_chunks = unsloth_num_chunks
412
+ pass
413
+
414
+ class _UnslothBCOTrainer(Trainer):
415
+ r""""""
416
+
417
+ _tag_names = ["trl", "bco"]
418
+
419
+ def __init__(
420
+ self,
421
+ model: Union[PreTrainedModel, nn.Module, str] = None,
422
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
423
+ args: BCOConfig = None,
424
+ train_dataset: Optional[Dataset] = None,
425
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
426
+ processing_class: Optional[
427
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
428
+ ] = None,
429
+ data_collator: Optional[DataCollator] = None,
430
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
431
+ callbacks: Optional[list[TrainerCallback]] = None,
432
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
433
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
434
+ peft_config: Optional[dict] = None,
435
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
436
+ model_adapter_name: Optional[str] = None,
437
+ ref_adapter_name: Optional[str] = None,
438
+ embedding_func: Optional[Callable] = None,
439
+ embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
440
+ ):
441
+ if not is_sklearn_available():
442
+ raise ImportError(
443
+ "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
444
+ )
445
+
446
+ if type(args) is TrainingArguments:
447
+ raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
448
+
449
+ if not isinstance(model, str) and ref_model is model:
450
+ raise ValueError(
451
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
452
+ "same as `model`, you must mass a copy of it, or `None` if you use peft."
453
+ )
454
+
455
+ if args.model_init_kwargs is None:
456
+ model_init_kwargs = {}
457
+ elif not isinstance(model, str):
458
+ raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
459
+ else:
460
+ model_init_kwargs = args.model_init_kwargs
461
+ torch_dtype = model_init_kwargs.get("torch_dtype")
462
+ if torch_dtype is not None:
463
+ # Convert to `torch.dtype` if an str is passed
464
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
465
+ torch_dtype = getattr(torch, torch_dtype)
466
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
467
+ raise ValueError(
468
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
469
+ )
470
+ model_init_kwargs["torch_dtype"] = torch_dtype
471
+
472
+ if args.ref_model_init_kwargs is None:
473
+ ref_model_init_kwargs = {}
474
+ elif not isinstance(ref_model, str):
475
+ raise ValueError(
476
+ "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
477
+ )
478
+ else:
479
+ ref_model_init_kwargs = args.ref_model_init_kwargs
480
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
481
+ if torch_dtype is not None:
482
+ # Convert to `torch.dtype` if an str is passed
483
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
484
+ torch_dtype = getattr(torch, torch_dtype)
485
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
486
+ raise ValueError(
487
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
488
+ )
489
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
490
+
491
+ if isinstance(model, str):
492
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
493
+
494
+ if isinstance(ref_model, str):
495
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
496
+
497
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
498
+ # has been called in order to properly call autocast if needed.
499
+ self._peft_has_been_casted_to_bf16 = False
500
+
501
+ if not is_peft_available() and peft_config is not None:
502
+ raise ValueError(
503
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
504
+ )
505
+ elif is_peft_available() and peft_config is not None:
506
+ # if model is a peft model and we have a peft_config, we merge and unload it first
507
+ if isinstance(model, PeftModel):
508
+ model = model.merge_and_unload()
509
+
510
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
511
+ _support_gc_kwargs = hasattr(
512
+ args, "gradient_checkpointing_kwargs"
513
+ ) and "gradient_checkpointing_kwargs" in list(
514
+ inspect.signature(prepare_model_for_kbit_training).parameters
515
+ )
516
+
517
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
518
+
519
+ if _support_gc_kwargs:
520
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
521
+
522
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
523
+ elif getattr(args, "gradient_checkpointing", False):
524
+ # For backward compatibility with older versions of transformers
525
+ if hasattr(model, "enable_input_require_grads"):
526
+ model.enable_input_require_grads()
527
+ else:
528
+
529
+ def make_inputs_require_grad(module, input, output):
530
+ output.requires_grad_(True)
531
+
532
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
533
+
534
+ # get peft model with the given config
535
+ model = model
536
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
537
+ peft_module_casting_to_bf16(model)
538
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
539
+ self._peft_has_been_casted_to_bf16 = True
540
+
541
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
542
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
543
+ # fail or completely fail.
544
+ elif getattr(args, "gradient_checkpointing", False):
545
+ # For backward compatibility with older versions of transformers
546
+ if hasattr(model, "enable_input_require_grads"):
547
+ model.enable_input_require_grads()
548
+ else:
549
+
550
+ def make_inputs_require_grad(module, input, output):
551
+ output.requires_grad_(True)
552
+
553
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
554
+
555
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
556
+ raise ValueError(
557
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
558
+ " Please install `wandb` or `comet-ml` to resolve."
559
+ )
560
+
561
+ if model is not None:
562
+ self.is_encoder_decoder = model.config.is_encoder_decoder
563
+ elif args.is_encoder_decoder is None:
564
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
565
+ else:
566
+ self.is_encoder_decoder = args.is_encoder_decoder
567
+
568
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
569
+ self.model_adapter_name = model_adapter_name
570
+ self.ref_adapter_name = ref_adapter_name
571
+
572
+ if ref_model:
573
+ self.ref_model = ref_model
574
+ elif self.is_peft_model or args.precompute_ref_log_probs:
575
+ # The `model` with adapters turned off will be used as the reference model
576
+ self.ref_model = None
577
+ else:
578
+ self.ref_model = create_reference_model(model)
579
+
580
+        if processing_class is None:
+            raise ValueError(
+                "A processing_class must be specified when using the default DPODataCollatorWithPadding."
+            )
584
+ if args.max_length is None:
585
+ warnings.warn(
586
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
587
+ "It will be set to `512` by default, but you should do it yourself in the future.",
588
+ UserWarning,
589
+ )
590
+ max_length = 512
591
+ if args.max_length is not None:
592
+ max_length = args.max_length
593
+
594
+ if args.max_prompt_length is None:
595
+ warnings.warn(
596
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
597
+ "It will be set to `128` by default, but you should do it yourself in the future.",
598
+ UserWarning,
599
+ )
600
+ max_prompt_length = 128
601
+ if args.max_prompt_length is not None:
602
+ max_prompt_length = args.max_prompt_length
603
+
604
+ max_completion_length = None
605
+        if args.max_completion_length is None and self.is_encoder_decoder:
+            warnings.warn(
+                "When using DPODataCollatorWithPadding with an encoder-decoder architecture, you should set "
+                "`max_completion_length` in the BCOConfig. It will be set to `128` by default, but you should do it "
+                "yourself in the future.",
+                UserWarning,
+            )
611
+ max_completion_length = 128
612
+ if args.max_completion_length is not None and self.is_encoder_decoder:
613
+ max_completion_length = args.max_completion_length
614
+
615
+ if data_collator is None:
616
+ data_collator = DPODataCollatorWithPadding(
617
+ pad_token_id=processing_class.pad_token_id,
618
+ label_pad_token_id=args.label_pad_token_id,
619
+ is_encoder_decoder=self.is_encoder_decoder,
620
+ )
621
+
622
+            if args.remove_unused_columns:
+                args.remove_unused_columns = False
+                # warn users
+                warnings.warn(
+                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your "
+                    "BCOConfig. We have set it for you, but you should do it yourself in the future.",
+                    UserWarning,
+                )
630
+
631
+ self.use_dpo_data_collator = True
632
+ else:
633
+ self.use_dpo_data_collator = False
634
+
635
+ # Disable dropout in the model and reference model
636
+ if args.disable_dropout:
637
+ disable_dropout_in_model(model)
638
+ if self.ref_model is not None:
639
+ disable_dropout_in_model(self.ref_model)
640
+
641
+ self.max_length = max_length
642
+ self.generate_during_eval = args.generate_during_eval
643
+ self.label_pad_token_id = args.label_pad_token_id
644
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
645
+ self.max_prompt_length = max_prompt_length
646
+ self.truncation_mode = args.truncation_mode
647
+ self.max_completion_length = max_completion_length
648
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
649
+
650
+        # Since ref_logps are precomputed on the first call to get_train/eval_dataloader,
+        # keep track of the first call to avoid recomputing them on later calls.
+        self._precomputed_train_ref_log_probs = False
+        self._precomputed_eval_ref_log_probs = False
654
+
655
+ # metric
656
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
657
+
658
+ # BCO parameter
659
+ self.beta = args.beta
660
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
661
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
662
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
663
+ warnings.warn(
664
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
665
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
666
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
667
+ "loss.",
668
+ UserWarning,
669
+ )
670
+
671
+ # Underlying Distribution Matching argument
672
+ self.embedding_func = embedding_func
673
+ self.embedding_tokenizer = embedding_tokenizer
674
+
675
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
676
+ # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
677
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
678
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
679
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
680
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
681
+ # issued.
682
+ model.warnings_issued["estimate_tokens"] = True
683
+
684
+ with PartialState().local_main_process_first():
685
+ # Apply the chat template if needed
686
+ train_dataset = train_dataset.map(
687
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
688
+ )
689
+ if eval_dataset is not None:
690
+ eval_dataset = eval_dataset.map(
691
+ maybe_apply_chat_template,
692
+ fn_kwargs={"tokenizer": processing_class},
693
+ num_proc=args.dataset_num_proc,
694
+ )
695
+ # Shuffle the datasets
696
+ train_dataset = train_dataset.shuffle(seed=args.data_seed)
697
+ if eval_dataset is not None:
698
+ eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
699
+ # Tokenize and prepare the training datasets
700
+ train_dataset = train_dataset.map(
701
+ _tokenize,
702
+ batched=True,
703
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
704
+ num_proc=args.dataset_num_proc,
705
+ desc="Tokenizing train dataset",
706
+ )
707
+
708
+ # Prepare the datasets
709
+ fn_kwargs = {
710
+ "prefix": "",
711
+ "is_encoder_decoder": self.is_encoder_decoder,
712
+ "tokenizer": processing_class,
713
+ "max_length": self.max_length,
714
+ "truncation_mode": self.truncation_mode,
715
+ "label_pad_token_id": self.label_pad_token_id,
716
+ "max_prompt_length": self.max_prompt_length,
717
+ "max_completion_length": self.max_completion_length,
718
+ }
719
+ train_dataset = train_dataset.map(
720
+ _process_tokens,
721
+ fn_kwargs=fn_kwargs,
722
+ num_proc=args.dataset_num_proc,
723
+ desc="Processing tokenized train dataset",
724
+ )
725
+
726
+ if eval_dataset is not None:
727
+ # Tokenize
728
+ eval_dataset = eval_dataset.map(
729
+ _tokenize,
730
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
731
+ batched=True,
732
+ num_proc=args.dataset_num_proc,
733
+ desc="Tokenizing eval dataset",
734
+ )
735
+
736
+ # Process
737
+ fn_kwargs = {
738
+ "prefix": "",
739
+ "is_encoder_decoder": self.is_encoder_decoder,
740
+ "tokenizer": processing_class,
741
+ "max_length": self.max_length,
742
+ "truncation_mode": self.truncation_mode,
743
+ "label_pad_token_id": self.label_pad_token_id,
744
+ "max_prompt_length": self.max_prompt_length,
745
+ "max_completion_length": self.max_completion_length,
746
+ }
747
+ eval_dataset = eval_dataset.map(
748
+ _process_tokens,
749
+ fn_kwargs=fn_kwargs,
750
+ num_proc=args.dataset_num_proc,
751
+ desc="Processing tokenized eval dataset",
752
+ )
753
+
754
+ desirable = train_dataset.filter(
755
+ lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
756
+ )
757
+ undesirable = train_dataset.filter(
758
+ lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
759
+ )
760
+
761
+ desirable = desirable.shuffle(seed=args.data_seed)
762
+ undesirable = undesirable.shuffle(seed=args.data_seed)
763
+
764
+ super().__init__(
765
+ model=model,
766
+ args=args,
767
+ data_collator=data_collator,
768
+ train_dataset=train_dataset,
769
+ eval_dataset=eval_dataset,
770
+ processing_class=processing_class,
771
+ model_init=model_init,
772
+ compute_metrics=compute_metrics,
773
+ callbacks=callbacks,
774
+ optimizers=optimizers,
775
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
776
+ )
777
+
778
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
779
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
780
+ # self.model_accepts_loss_kwargs to False to enable scaling.
781
+ self.model_accepts_loss_kwargs = False
782
+
783
+ # Add tags for models that have been loaded with the correct transformers version
784
+ if hasattr(self.model, "add_model_tags"):
785
+ self.model.add_model_tags(self._tag_names)
786
+
787
+ if not hasattr(self, "accelerator"):
788
+ raise AttributeError(
789
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
790
+ )
791
+
792
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
793
+ if self.is_deepspeed_enabled:
794
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
795
+ raise ValueError(
796
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
797
+ )
798
+
799
+ if self.ref_model is None:
800
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
801
+ raise ValueError(
802
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
803
+ )
804
+ else:
805
+ if self.is_deepspeed_enabled:
806
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
807
+ else:
808
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
809
+
810
+ self.running = RunningMoments(accelerator=self.accelerator)
811
+
812
+ if self.embedding_func is None:
813
+ return
814
+
815
+ chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
816
+ rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
817
+
818
+ embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
819
+ labels = torch.cat(
820
+ (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
821
+ )
822
+
823
+ self.clf = LogisticRegression(class_weight="balanced").fit(
824
+ embeddings.cpu().float().numpy(), labels.cpu().numpy()
825
+ )
826
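+        # The classifier above is the UDM (underlying distribution matching) density-ratio
+        # estimator: a balanced logistic regression over frozen prompt embeddings, roughly
+        #     p(desirable | prompt) = sigmoid(w . embed(prompt) + b),
+        # whose odds p / (1 - p) later weight the losses of undesirable examples
+        # (see _get_udm_weight below).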
+
827
+ @property
828
+ def match_underlying_distribution(self):
829
+ return self.embedding_func is not None and self.embedding_tokenizer is not None
830
+
831
+    def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Calculates the probability that the given prompt embedding comes from the desirable dataset.
+        The probability is computed on each process and then ensembled across processes.
+        """
836
+ dtype = prompt_embeddings.dtype
837
+ device = prompt_embeddings.device
838
+ rank = self.accelerator.process_index
839
+
840
+ padded_prompt_embeddings = self.accelerator.pad_across_processes(
841
+ prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
842
+ )
843
+ sample_size = padded_prompt_embeddings.shape[0]
844
+ nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
845
+ prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
846
+
847
+ # cannot predict for all empty values
848
+ if prompt_embeddings.shape[0] == 0:
849
+ return torch.tensor([], device=device, dtype=dtype)
850
+
851
+ prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
852
+ prob = torch.as_tensor(prob, dtype=dtype, device=device)
853
+ prob = self.accelerator.reduce(prob, reduction="mean")
854
+
855
+ prob = prob[sample_size * rank : sample_size * (rank + 1)]
856
+ prob = prob[nonzero]
857
+
858
+ return prob
859
+
860
+    def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
+        """
+        Replaces processing_class.pad_token_id with embedding_tokenizer.pad_token_id
+        and applies self.embedding_func.
+        """
865
+ input_ids = torch.where(
866
+ input_ids == self.processing_class.pad_token_id,
867
+ self.embedding_tokenizer.pad_token_id,
868
+ input_ids,
869
+ )
870
+
871
+ with torch.no_grad():
872
+ embeddings = self.embedding_func(
873
+ input_ids=input_ids,
874
+ attention_mask=attention_mask,
875
+ )
876
+
877
+ return embeddings
878
+
879
+ def _get_prompt_embeddings(
880
+ self, batch: dict[str, Union[list, torch.LongTensor]]
881
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
882
+ """Extract embeddings from frozen embedding model"""
883
+
884
+ if not self.match_underlying_distribution:
885
+ return None, None
886
+
887
+ embeddings = self._vectorize_prompt(
888
+ input_ids=batch["embedding_input_ids"],
889
+ attention_mask=batch["embedding_attention_mask"],
890
+ )
891
+
892
+ chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
893
+ rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
894
+
895
+ chosen_embeddings = embeddings[chosen_idx, ...]
896
+ rejected_embeddings = embeddings[rejected_idx, ...]
897
+
898
+ return (chosen_embeddings, rejected_embeddings)
899
+
900
+ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
901
+ """
902
+ Sample instances from dataset and get prompt embeddings.
903
+ Used for density ratio classifier training.
904
+ """
905
+ n_samples = min(len(dataset), sample_size)
906
+ rand_indices = np.random.choice(len(dataset), size=(n_samples,))
907
+
908
+ embedding_dataset = dataset.select(rand_indices)
909
+
910
+ dataloader_params = {
911
+ "batch_size": self.args.per_device_train_batch_size,
912
+ "collate_fn": self.data_collator,
913
+ "num_workers": self.args.dataloader_num_workers,
914
+ "pin_memory": self.args.dataloader_pin_memory,
915
+ "shuffle": False,
916
+ }
917
+
918
+ # prepare dataloader
919
+ data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
920
+
921
+ with torch.no_grad():
922
+ all_embeddings = torch.empty(0)
923
+ for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
924
+ embeddings = self._vectorize_prompt(
925
+ input_ids=padded_batch["embedding_input_ids"],
926
+ attention_mask=padded_batch["embedding_attention_mask"],
927
+ )
928
+ embeddings = self.accelerator.gather_for_metrics(embeddings)
929
+ all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
930
+
931
+ return all_embeddings
932
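+    # In short: a fixed random sample of prompts is embedded batch by batch with the frozen
+    # embedding model, gathered across processes, and accumulated on CPU; these embeddings
+    # are what the density-ratio classifier in __init__ is fitted on.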
+
933
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
934
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
935
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
936
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
937
+
938
+ if model is not None:
939
+ if hasattr(model, "config"):
940
+ hidden_size = (
941
+ max(model.config.hidden_sizes)
942
+ if getattr(model.config, "hidden_sizes", None)
943
+ else getattr(model.config, "hidden_size", None)
944
+ )
945
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
946
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
947
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
948
+ config_kwargs.update(
949
+ {
950
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
951
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
952
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
953
+ }
954
+ )
955
+
956
+ # If ZeRO-3 is used, we shard both the active and reference model.
957
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
958
+ if config_kwargs["zero_optimization"]["stage"] != 3:
959
+ config_kwargs["zero_optimization"]["stage"] = 0
960
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
961
+ model.eval()
962
+ return model
963
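+    # Worked example of the ZeRO-3 bucket sizing above (illustrative numbers, not from this
+    # file): for hidden_size = 4096, reduce_bucket_size = 4096**2 = 16_777_216,
+    # stage3_param_persistence_threshold = 10 * 4096 = 40_960, and
+    # stage3_prefetch_bucket_size = 0.9 * 4096**2 = 15_099_494.4.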
+
964
+ def _save_optimizer_and_scheduler(self, output_dir):
965
+ super()._save_optimizer_and_scheduler(output_dir)
966
+
967
+ # When saving optimizer and scheduler to checkpoint, save also the running delta object.
968
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
969
+
970
+ self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
971
+
972
+ if self.match_underlying_distribution:
973
+ torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
974
+
975
+ def _load_optimizer_and_scheduler(self, checkpoint):
976
+ super()._load_optimizer_and_scheduler(checkpoint)
977
+
978
+ if checkpoint is None:
979
+ return
980
+ # when loading optimizer and scheduler from checkpoint, also load the running delta object.
981
+ running_file = os.path.join(checkpoint, RUNNING_NAME)
982
+ if os.path.isfile(running_file):
983
+ self.running = RunningMoments.load_from_json(self.accelerator, running_file)
984
+
985
+        if self.match_underlying_distribution:
+            clf_file = os.path.join(checkpoint, CLF_NAME)
+            if os.path.isfile(clf_file):
+                self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
989
+
990
+ @contextmanager
991
+ def null_ref_context(self):
992
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
993
+ with (
994
+ self.accelerator.unwrap_model(self.model).disable_adapter()
995
+ if self.is_peft_model and not self.ref_adapter_name
996
+ else nullcontext()
997
+ ):
998
+ if self.ref_adapter_name:
999
+ self.model.set_adapter(self.ref_adapter_name)
1000
+ yield
1001
+ if self.ref_adapter_name:
1002
+ self.model.set_adapter(self.model_adapter_name or "default")
1003
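+    # Descriptive note: with a PEFT model and no explicit ref adapter, the context manager
+    # above temporarily disables the adapter so the frozen base weights act as the reference
+    # policy; when ref_adapter_name is set, it swaps adapters instead and restores the
+    # training adapter on exit.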
+
1004
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training [`~torch.utils.data.DataLoader`].
+
+        Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
+        """
1010
+
1011
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
1012
+ dataloader_params = {
1013
+ "batch_size": self.args.per_device_train_batch_size,
1014
+ "collate_fn": self.data_collator,
1015
+ "num_workers": self.args.dataloader_num_workers,
1016
+ "pin_memory": self.args.dataloader_pin_memory,
1017
+ "shuffle": False,
1018
+ }
1019
+
1020
+ # prepare dataloader
1021
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
1022
+ reference_completion_logps = []
1023
+
1024
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
1025
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1026
+
1027
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1028
+ reference_completion_logps.append(reference_completion_logp.cpu())
1029
+
1030
+ self.train_dataset = self.train_dataset.add_column(
1031
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1032
+ )
1033
+
1034
+ self._precomputed_train_ref_log_probs = True
1035
+
1036
+ return super().get_train_dataloader()
1037
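+    # Flow of the precompute path above: on the first call the reference model is run once
+    # over the whole train set, the resulting log-probs are stored as a "reference_logps"
+    # column, and every later call goes straight to the parent implementation.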
+
1038
+    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+        """
+        Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+        Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
+
+        Args:
+            eval_dataset (`torch.utils.data.Dataset`, *optional*):
+                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+                by the `model.forward()` method are automatically removed. It must implement `__len__`.
+        """
1049
+ if eval_dataset is None and self.eval_dataset is None:
1050
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
1051
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
1052
+
1053
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
1054
+ dataloader_params = {
1055
+ "batch_size": self.args.per_device_eval_batch_size,
1056
+ "collate_fn": self.data_collator,
1057
+ "num_workers": self.args.dataloader_num_workers,
1058
+ "pin_memory": self.args.dataloader_pin_memory,
1059
+ "shuffle": False,
1060
+ }
1061
+
1062
+ # prepare dataloader
1063
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1064
+
1065
+ reference_completion_logps = []
1066
+
1067
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1068
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1069
+
1070
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1071
+ reference_completion_logps.append(reference_completion_logp.cpu())
1072
+
1073
+ eval_dataset = eval_dataset.add_column(
1074
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1075
+ )
1076
+
1077
+ # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
1078
+ if self.eval_dataset is not None:
1079
+ self.eval_dataset = eval_dataset
1080
+ self._precomputed_eval_ref_log_probs = True
1081
+
1082
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
1083
+
1084
+    def compute_reference_log_probs(self, padded_batch: dict) -> torch.FloatTensor:
+        """Computes log probabilities of the reference model for a single padded batch of a BCO-specific dataset."""
1086
+ with torch.no_grad():
1087
+ if self.ref_model is None:
1088
+ with self.null_ref_context():
1089
+ if self.is_encoder_decoder:
1090
+ completion_logits = self.model(
1091
+ padded_batch["prompt_input_ids"],
1092
+ attention_mask=padded_batch["prompt_attention_mask"],
1093
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1094
+ labels=padded_batch["completion_labels"],
1095
+ ).logits
1096
+
1097
+ else:
1098
+ completion_logits = self.model(
1099
+ padded_batch["completion_input_ids"],
1100
+ attention_mask=padded_batch["completion_attention_mask"],
1101
+ ).logits
1102
+
1103
+ else:
1104
+ if self.is_encoder_decoder:
1105
+ completion_logits = self.ref_model(
1106
+ padded_batch["prompt_input_ids"],
1107
+ attention_mask=padded_batch["prompt_attention_mask"],
1108
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1109
+ labels=padded_batch["completion_labels"],
1110
+ ).logits
1111
+
1112
+ else:
1113
+ completion_logits = self.ref_model(
1114
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1115
+ ).logits
1116
+
1117
+ completion_logps = self.get_batch_logps(
1118
+ completion_logits,
1119
+ padded_batch["completion_labels"],
1120
+ average_log_prob=False,
1121
+ is_encoder_decoder=self.is_encoder_decoder,
1122
+ label_pad_token_id=self.label_pad_token_id,
1123
+ )
1124
+
1125
+ return completion_logps
1126
+
1127
+ @staticmethod
1128
+ def get_batch_logps(
1129
+ logits: torch.FloatTensor,
1130
+ labels: torch.LongTensor,
1131
+ average_log_prob: bool = False,
1132
+ label_pad_token_id: int = -100,
1133
+ is_encoder_decoder: bool = False,
1134
+ ) -> torch.FloatTensor:
1135
+ """Compute the log probabilities of the given labels under the given logits.
1136
+
1137
+ Args:
1138
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1139
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1140
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1141
+
1142
+ Returns:
1143
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1144
+ """
1145
+ if logits.shape[:-1] != labels.shape:
1146
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1147
+
1148
+ if not is_encoder_decoder:
1149
+ labels = labels[:, 1:].clone()
1150
+ logits = logits[:, :-1, :]
1151
+        else:
+            # Fixes enc-dec RuntimeError
+            labels = labels.clone()
1154
+
1155
+ loss_mask = labels != label_pad_token_id
1156
+
1157
+ # dummy token; we'll ignore the losses on these tokens later
1158
+ labels[labels == label_pad_token_id] = 0
1159
+
1160
+ per_token_logps = selective_log_softmax(logits, labels)
1161
+
1162
+ if average_log_prob:
1163
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1164
+ else:
1165
+ return (per_token_logps * loss_mask).sum(-1)
1166
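+    # A minimal worked example of the shift-and-mask logic above (illustrative values): for a
+    # decoder-only model, token labels[t] is predicted from logits at t - 1, so logits[:, :-1]
+    # are scored against labels[:, 1:]; positions equal to label_pad_token_id (-100 by default)
+    # are masked out, e.g. labels [-100, -100, 52, 19] contribute only log p(52) + log p(19).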
+
1167
+ def forward(
1168
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1169
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1170
+ model_kwargs = (
1171
+ {
1172
+ "labels": batch["completion_labels"],
1173
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1174
+ }
1175
+ if self.is_encoder_decoder
1176
+ else {}
1177
+ )
1178
+ if self.aux_loss_enabled:
1179
+ model_kwargs["output_router_logits"] = True
1180
+
1181
+ outputs = model(
1182
+ batch["completion_input_ids"],
1183
+ attention_mask=batch["completion_attention_mask"],
1184
+ **model_kwargs,
1185
+ )
1186
+ completion_logits = outputs.logits
1187
+
1188
+ completion_logps = self.get_batch_logps(
1189
+ completion_logits,
1190
+ batch["completion_labels"],
1191
+ average_log_prob=False,
1192
+ is_encoder_decoder=self.is_encoder_decoder,
1193
+ label_pad_token_id=self.label_pad_token_id,
1194
+ )
1195
+
1196
+ if completion_logps.shape[0] != len(batch["label"]):
1197
+ raise ValueError(
1198
+ "There is a mismatch between the number of examples in this batch and the number of "
1199
+ "examples for which an output sequence was predicted."
1200
+ )
1201
+
1202
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1203
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1204
+
1205
+ chosen_logps = completion_logps[chosen_idx, ...]
1206
+ rejected_logps = completion_logps[rejected_idx, ...]
1207
+
1208
+ chosen_logits = completion_logits[chosen_idx, ...]
1209
+ rejected_logits = completion_logits[rejected_idx, ...]
1210
+
1211
+ if self.aux_loss_enabled:
1212
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
1213
+ else:
1214
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
1215
+
1216
+ def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
1217
+ prob_desirable = self._get_chosen_prob(rejected_embeddings)
1218
+ min_ratio = self.args.min_density_ratio
1219
+ max_ratio = self.args.max_density_ratio
1220
+
1221
+ weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
1222
+
1223
+ return weight
1224
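+    # Worked example (illustrative): if the classifier assigns prob_desirable = 0.8 to a
+    # rejected prompt, the raw odds are 0.8 / 0.2 = 4.0, which is then clamped into
+    # [min_density_ratio, max_density_ratio] from the config.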
+
1225
+ def bco_loss(
1226
+ self,
1227
+ policy_chosen_logps: torch.FloatTensor,
1228
+ policy_rejected_logps: torch.FloatTensor,
1229
+ reference_chosen_logps: torch.FloatTensor,
1230
+ reference_rejected_logps: torch.FloatTensor,
1231
+ chosen_embeddings: Optional[torch.FloatTensor],
1232
+ rejected_embeddings: Optional[torch.FloatTensor],
1233
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1234
+ """Compute the BCO loss for a batch of policy and reference model log probabilities.
1235
+
1236
+ Args:
1237
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1238
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1239
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1240
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1241
+ chosen_embeddings: embeddings of desirable prompts
1242
+ rejected_embeddings: embeddings of undesirable prompts
1243
+
1244
+ Returns:
1245
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
1246
+ The losses tensor contains the BCO loss for each example in the batch.
1247
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1248
+ The delta value contains the moving average of all implicit rewards.
1249
+ """
1250
+
1251
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1252
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1253
+ chosen_rewards = self.beta * chosen_logratios
1254
+ else:
1255
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1256
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1257
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1258
+
1259
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1260
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1261
+ rejected_rewards = self.beta * rejected_logratios
1262
+ else:
1263
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1264
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1265
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1266
+
1267
+ rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
1268
+ self.running.update(rewards)
1269
+ delta = self.running.mean
1270
+
1271
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1272
+ chosen_losses = -F.logsigmoid(chosen_rewards - delta)
1273
+
1274
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1275
+ rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
1276
+
1277
+ if self.match_underlying_distribution:
1278
+ chosen_weight = torch.ones_like(chosen_losses)
1279
+ rejected_weight = self._get_udm_weight(rejected_embeddings)
1280
+
1281
+ losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
1282
+ else:
1283
+ losses = torch.cat((chosen_losses, rejected_losses), dim=0)
1284
+
1285
+ return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
1286
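+    # BCO loss recap (summary of the code above):
+    #     r_chosen   = beta * (log pi(y|x) - log pi_ref(y|x))
+    #     delta      = running mean of the implicit rewards
+    #     L_chosen   = -log sigmoid(r_chosen - delta)
+    #     L_rejected = -log sigmoid(-(r_rejected - delta))
+    # With underlying-distribution matching enabled, each rejected loss is additionally
+    # scaled by the clamped density ratio from _get_udm_weight.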
+
1287
+ def get_batch_loss_metrics(
1288
+ self,
1289
+ model,
1290
+ batch: dict[str, Union[list, torch.LongTensor]],
1291
+ ):
1292
+ """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
1293
+ metrics = {}
1294
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1295
+
1296
+ forward_output = self.forward(model, batch)
1297
+ (
1298
+ policy_chosen_logps,
1299
+ policy_rejected_logps,
1300
+ policy_chosen_logits,
1301
+ policy_rejected_logits,
1302
+ ) = forward_output[:4]
1303
+ if self.aux_loss_enabled:
1304
+ aux_loss = forward_output[4]
1305
+
1306
+ # if reference_logps in batch use them, otherwise use the reference model
1307
+ if "reference_logps" in batch:
1308
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1309
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1310
+
1311
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1312
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1313
+ else:
1314
+ with torch.no_grad():
1315
+ if self.ref_model is None:
1316
+ with self.null_ref_context():
1317
+ (
1318
+ reference_chosen_logps,
1319
+ reference_rejected_logps,
1320
+ _,
1321
+ _,
1322
+ ) = self.forward(self.model, batch)[:4]
1323
+ else:
1324
+ (
1325
+ reference_chosen_logps,
1326
+ reference_rejected_logps,
1327
+ _,
1328
+ _,
1329
+ ) = self.forward(self.ref_model, batch)[:4]
1330
+
1331
+ chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
1332
+
1333
+ losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
1334
+ policy_chosen_logps,
1335
+ policy_rejected_logps,
1336
+ reference_chosen_logps,
1337
+ reference_rejected_logps,
1338
+ chosen_embeddings,
1339
+ rejected_embeddings,
1340
+ )
1341
+ metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
1342
+
1343
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1344
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1345
+
1346
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1347
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1348
+
1349
+ if all_num_chosen > 0:
1350
+ metrics["rewards/chosen_sum"] = (
1351
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1352
+ )
1353
+ metrics["logps/chosen_sum"] = (
1354
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1355
+ )
1356
+ metrics["logits/chosen_sum"] = (
1357
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1358
+ )
1359
+ metrics["count/chosen"] = all_num_chosen
1360
+
1361
+ if all_num_rejected > 0:
1362
+ metrics["rewards/rejected_sum"] = (
1363
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1364
+ )
1365
+ metrics["logps/rejected_sum"] = (
1366
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1367
+ )
1368
+ metrics["logits/rejected_sum"] = (
1369
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1370
+ )
1371
+ metrics["count/rejected"] = all_num_rejected
1372
+
1373
+ loss = losses.nanmean()
1374
+ if self.aux_loss_enabled:
1375
+ loss += self.aux_loss_coef * aux_loss
1376
+
1377
+ return loss, metrics
1378
+
1379
+ def compute_loss(
1380
+ self,
1381
+ model: Union[PreTrainedModel, nn.Module],
1382
+ inputs: dict[str, Union[torch.Tensor, Any]],
1383
+ return_outputs=False,
1384
+ num_items_in_batch=None,
1385
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1386
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1387
+
1388
+ with compute_loss_context_manager:
1389
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1390
+
1391
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1392
+ loss = loss.to(self.args.device)
1393
+ # force log the metrics
1394
+ if self.accelerator.is_main_process:
1395
+ self.store_metrics(metrics, train_eval="train")
1396
+
1397
+ if return_outputs:
1398
+ return (loss, metrics)
1399
+ return loss
1400
+
1401
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1402
+ for key, value in metrics.items():
1403
+ self._stored_metrics[train_eval][key].append(value)
1404
+
1405
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1406
+ if self.train_dataset is None or not has_length(self.train_dataset):
1407
+ return None
1408
+ return SequentialSampler(self.train_dataset)
1409
+
1410
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1411
+ """Generate samples from the model and reference model for the given batch of inputs."""
1412
+
1413
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1414
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1415
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1416
+ with generate_context_manager:
1417
+ policy_output = model.generate(
1418
+ input_ids=batch["prompt_input_ids"],
1419
+ attention_mask=batch["prompt_attention_mask"],
1420
+ max_length=self.max_length,
1421
+ do_sample=True,
1422
+ pad_token_id=self.processing_class.pad_token_id,
1423
+ )
1424
+
1425
+ # if reference_output in batch use that otherwise use the reference model
1426
+ if "reference_output" in batch:
1427
+ reference_output = batch["reference_output"]
1428
+ else:
1429
+ if self.ref_model is None:
1430
+ with self.null_ref_context():
1431
+ reference_output = self.model.generate(
1432
+ input_ids=batch["prompt_input_ids"],
1433
+ attention_mask=batch["prompt_attention_mask"],
1434
+ max_length=self.max_length,
1435
+ do_sample=True,
1436
+ pad_token_id=self.processing_class.pad_token_id,
1437
+ )
1438
+ else:
1439
+ reference_output = self.ref_model.generate(
1440
+ input_ids=batch["prompt_input_ids"],
1441
+ attention_mask=batch["prompt_attention_mask"],
1442
+ max_length=self.max_length,
1443
+ do_sample=True,
1444
+ pad_token_id=self.processing_class.pad_token_id,
1445
+ )
1446
+
1447
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1448
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1449
+
1450
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1451
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1452
+
1453
+ return policy_output_decoded, reference_output_decoded
1454
+
1455
+ def prediction_step(
1456
+ self,
1457
+ model: Union[PreTrainedModel, nn.Module],
1458
+ inputs: dict[str, Union[torch.Tensor, Any]],
1459
+ prediction_loss_only: bool,
1460
+ ignore_keys: Optional[list[str]] = None,
1461
+ ):
1462
+ if ignore_keys is None:
1463
+ if hasattr(model, "config"):
1464
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1465
+ else:
1466
+ ignore_keys = []
1467
+
1468
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1469
+ with torch.no_grad(), prediction_context_manager:
1470
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1471
+
1472
+ # force log the metrics
1473
+ if self.accelerator.is_main_process:
1474
+ self.store_metrics(metrics, train_eval="eval")
1475
+
1476
+ if prediction_loss_only:
1477
+ return (loss.detach(), None, None)
1478
+
1479
+ # logits for the chosen and rejected samples from model
1480
+ logits_dict = {
1481
+ "eval_logits/chosen": metrics["logits/chosen"],
1482
+ "eval_logits/rejected": metrics["logits/rejected"],
1483
+ }
1484
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1485
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1486
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1487
+
1488
+ return (loss.detach(), logits, labels)
1489
+
1490
+ def evaluation_loop(
1491
+ self,
1492
+ dataloader: DataLoader,
1493
+ description: str,
1494
+ prediction_loss_only: Optional[bool] = None,
1495
+ ignore_keys: Optional[list[str]] = None,
1496
+ metric_key_prefix: str = "eval",
1497
+ ) -> EvalLoopOutput:
1498
+ """
1499
+ Overriding built-in evaluation loop to store metrics for each batch.
1500
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1501
+
1502
+ Works both with or without labels.
1503
+ """
1504
+
1505
+ # Sample and save to game log if requested (for one batch to save time)
1506
+ if self.generate_during_eval:
1507
+ # Generate random indices within the range of the total number of samples
1508
+ num_samples = len(dataloader.dataset)
1509
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1510
+
1511
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1512
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1513
+ random_batch = self.data_collator(random_batch_dataset)
1514
+ random_batch = self._prepare_inputs(random_batch)
1515
+
1516
+            target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
+            target_batch = {
+                "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
+                "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
+                "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
+            }
1522
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1523
+
1524
+ table = pd.DataFrame(
1525
+ columns=["Prompt", "Policy", "Ref Model"],
1526
+ data=[
1527
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1528
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1529
+ ],
1530
+ )
1531
+ if "wandb" in self.args.report_to:
1532
+ wandb.log({"game_log": wandb.Table(data=table)})
1533
+
1534
+ if "comet_ml" in self.args.report_to:
1535
+ log_table_to_comet_experiment(
1536
+ name="game_log.csv",
1537
+ table=table,
1538
+ )
1539
+
1540
+ # Base evaluation
1541
+ initial_output = super().evaluation_loop(
1542
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1543
+ )
1544
+
1545
+ return initial_output
1546
+
1547
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1548
+ """
1549
+ Log `logs` on the various objects watching training, including stored metrics.
1550
+
1551
+ Args:
1552
+ logs (`dict[str, float]`):
1553
+ The values to log.
1554
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1555
+ Start time of the training.
1556
+ """
1557
+ # logs either has 'loss' or 'eval_loss'
1558
+ train_eval = "train" if "loss" in logs else "eval"
1559
+ # train metrics should have no prefix, eval should have 'eval_'
1560
+ prefix = "eval_" if train_eval == "eval" else ""
1561
+ # accumulate average metrics from sums and lengths
1562
+ for split in ["chosen", "rejected"]:
1563
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1564
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1565
+ for metric in ["rewards", "logps", "logits"]:
1566
+ logs[f"{prefix}{metric}/{split}"] = (
1567
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1568
+ / count_sum
1569
+ )
1570
+ # delete obsolete metric
1571
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1572
+ del self._stored_metrics[train_eval][f"count/{split}"]
1573
+ # calculate reward margin
1574
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1575
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1576
+ # Add averaged stored metrics to logs
1577
+ for key, metrics in self._stored_metrics[train_eval].items():
1578
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1579
+ del self._stored_metrics[train_eval]
1580
+
1581
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1582
+ return super().log(logs, start_time)
1583
+ else: # transformers<=4.46
1584
+ return super().log(logs)
1585
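+    # Example of the averaging performed above (illustrative numbers): if two batches stored
+    # rewards/chosen_sum = [12.0, 8.0] and count/chosen = [4, 4], the logged value is
+    # rewards/chosen = (12.0 + 8.0) / (4 + 4) = 2.5.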
+
1586
+ def create_model_card(
1587
+ self,
1588
+ model_name: Optional[str] = None,
1589
+ dataset_name: Optional[str] = None,
1590
+ tags: Union[str, list[str], None] = None,
1591
+ ):
1592
+ """
1593
+ Creates a draft of a model card using the information available to the `Trainer`.
1594
+
1595
+ Args:
1596
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1597
+ Name of the model.
1598
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1599
+ Name of the dataset used for training.
1600
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1601
+ Tags to be associated with the model card.
1602
+ """
1603
+ if not self.is_world_process_zero():
1604
+ return
1605
+
1606
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1607
+ base_model = self.model.config._name_or_path
1608
+ else:
1609
+ base_model = None
1610
+
1611
+ tags = tags or []
1612
+ if isinstance(tags, str):
1613
+ tags = [tags]
1614
+
1615
+ if hasattr(self.model.config, "unsloth_version"):
1616
+ tags.append("unsloth")
1617
+
1618
+ citation = textwrap.dedent("""\
1619
+ @article{jung2024binary,
1620
+ title = {{Binary Classifier Optimization for Large Language Model Alignment}},
1621
+ author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
1622
+ year = 2024,
1623
+ eprint = {arXiv:2404.04656}
1624
+ }""")
1625
+
1626
+ model_card = generate_model_card(
1627
+ base_model=base_model,
1628
+ model_name=model_name,
1629
+ hub_model_id=self.hub_model_id,
1630
+ dataset_name=dataset_name,
1631
+ tags=tags,
1632
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1633
+ comet_url=get_comet_experiment_url(),
1634
+ trainer_name="BCO",
1635
+ trainer_citation=citation,
1636
+ paper_title="Binary Classifier Optimization for Large Language Model Alignment",
1637
+ paper_id="2404.04656",
1638
+ )
1639
+
1640
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1641
+ class UnslothBCOTrainer(_UnslothBCOTrainer):
1642
+ """
1643
+
1644
+ Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
1645
+
1646
+ Args:
1647
+        model (`transformers.PreTrainedModel`):
+            The model to train, preferably an `AutoModelForCausalLM`.
+        ref_model (`PreTrainedModelWrapper`):
+            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
+            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1652
+ args (`BCOConfig`):
1653
+ The arguments to use for training.
1654
+ train_dataset (`datasets.Dataset`):
1655
+ The dataset to use for training.
1656
+ eval_dataset (`datasets.Dataset`):
1657
+ The dataset to use for evaluation.
1658
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1659
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1660
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1661
+ reuse the fine-tuned model.
1662
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1663
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1664
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1665
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1666
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1667
+ callbacks (`list[transformers.TrainerCallback]`):
1668
+ The callbacks to use for training.
1669
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1670
+ The optimizer and scheduler to use for training.
1671
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1672
+ The function to use to preprocess the logits before computing the metrics.
1673
+ peft_config (`dict`, defaults to `None`):
1674
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1675
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1676
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1677
+ a dictionary string to metric values.
1678
+ model_adapter_name (`str`, defaults to `None`):
1679
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1680
+ ref_adapter_name (`str`, defaults to `None`):
1681
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1682
+
1683
+ """
1684
+ def __init__(
1685
+ self,
1686
+ model = None,
1687
+ ref_model = None,
1688
+ args = None,
1689
+ train_dataset = None,
1690
+ eval_dataset = None,
1691
+ processing_class = None,
1692
+ data_collator = None,
1693
+ model_init = None,
1694
+ callbacks = None,
1695
+ preprocess_logits_for_metrics = None,
1696
+ peft_config = None,
1697
+ compute_metrics = None,
1698
+ model_adapter_name = None,
1699
+ ref_adapter_name = None,
1700
+ embedding_func = None,
1701
+ embedding_tokenizer = None,
1702
+ **kwargs
1703
+ ):
1704
+ if args is None: args = UnslothBCOConfig()
1705
+ use_bf16 = getattr(args, 'bf16', False)
1706
+ use_fp16 = getattr(args, 'fp16', False)
1707
+ force_float32 = False
1708
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1709
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1710
+ force_float32 = True
1711
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1712
+ dtype = getattr(model.config, 'torch_dtype', None)
1713
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1714
+ from unsloth_zoo.utils import _get_dtype
1715
+ dtype = _get_dtype(dtype)
1716
+ float16 = dtype == torch.float16
1717
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1718
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1719
+ if force_float32:
1720
+ args.fp16 = False
1721
+ args.bf16 = False
1722
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1723
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1724
+ args.fp16 = float16
1725
+ args.bf16 = not float16
1726
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1727
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1728
+ args.eval_strategy = 'steps'
1729
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1730
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1731
+ if ga_steps is not None and ga_steps > 1:
1732
+ from transformers import __version__ as transformers_version
1733
+             if Version(transformers_version) <= Version('4.45.2'):
+                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+         if getattr(args, 'eval_strategy', 'no') != 'no':
+             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+         if force_float32:
+             args.bf16_full_eval = False
+             args.fp16_full_eval = False
+         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+             args.bf16_full_eval = True
+             args.fp16_full_eval = False
+         elif not bf16_full_eval and not fp16_full_eval:
+             args.bf16_full_eval = args.bf16
+             args.fp16_full_eval = args.fp16
+         _output_logits = False
+         if locals().get('compute_metrics', None) is not None: _output_logits = True
+         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+         if _output_logits:
+             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+             pass
+         else:
+             model_max_seq_length = getattr(model, 'max_seq_length', None)
+             args_max_seq_length = getattr(args, 'max_seq_length', None)
+             if args_max_seq_length is None and model_max_seq_length is not None:
+                 max_seq_length = model.max_seq_length
+                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+         if model is not None and hasattr(model, 'for_training'):
+             model.for_training()
+         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+         if 'processing_class' in locals():
+             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
+         else:
+             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                 if isinstance(data_collator, DataCollatorForSeq2Seq):
+                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+                 else:
+                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+         other_metrics = []
+
+         from unsloth_zoo.logging_utils import PatchRLStatistics
+         PatchRLStatistics('bco_trainer', other_metrics)
+
+         super().__init__(
+             model = model,
+             ref_model = ref_model,
+             args = args,
+             train_dataset = train_dataset,
+             eval_dataset = eval_dataset,
+             processing_class = processing_class,
+             data_collator = data_collator,
+             model_init = model_init,
+             callbacks = callbacks,
+             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+             peft_config = peft_config,
+             compute_metrics = compute_metrics,
+             model_adapter_name = model_adapter_name,
+             ref_adapter_name = ref_adapter_name,
+             embedding_func = embedding_func,
+             embedding_tokenizer = embedding_tokenizer, **kwargs)
+         if hasattr(self, 'neftune_hook_handle'):
+             self.neftune_hook_handle.remove()
+         if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+         if getattr(args, 'neftune_noise_alpha', None) is not None:
+             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+         pass
+
+ pass
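For readers skimming the generated constructor above: the fp16/bf16 full-eval resolution boils down to "eval precision follows training precision unless explicitly set or forced to float32". A minimal standalone sketch, not part of this commit (the function name and return convention are mine):

import os

def resolve_eval_precision(fp16, bf16, fp16_full_eval, bf16_full_eval,
                           force_float32, mixed_precision = "float32"):
    # Returns (bf16_full_eval, fp16_full_eval), mirroring the branch order above.
    if fp16 and bf16_full_eval: bf16_full_eval, fp16_full_eval = False, True
    if bf16 and fp16_full_eval: bf16_full_eval, fp16_full_eval = True, False
    if force_float32:
        return False, False
    elif mixed_precision == "bfloat16":      # stands in for UNSLOTH_MIXED_PRECISION
        return True, False
    elif not bf16_full_eval and not fp16_full_eval:
        return bf16, fp16                    # inherit the training dtype
    return bf16_full_eval, fp16_full_eval

print(resolve_eval_precision(fp16=True, bf16=False, fp16_full_eval=False,
                             bf16_full_eval=False, force_float32=False))
# -> (False, True): fp16 training implies fp16 full eval when nothing is set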
unsloth_compiled_cache/UnslothCPOTrainer.py ADDED
@@ -0,0 +1,1551 @@
+ """
+ 2025.5.8
+ 2025.5.7
+ 4.51.3
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
+
+
+ import os
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+ torch_compile_options = {
+     "epilogue_fusion"   : True,
+     "max_autotune"      : False,
+     "shape_padding"     : True,
+     "trace.enabled"     : False,
+     "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+     logits = logits.to(torch.float32)
+     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+     # loop to reduce peak mem consumption
+     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+     logsumexp_values = torch.logsumexp(logits, dim = -1)
+     per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+     return per_token_logps
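A quick sanity check of what `selective_log_softmax` computes (a sketch, not part of the commit; it assumes only `torch`): gathering the chosen token's logit and subtracting logsumexp should match gathering from a full `log_softmax`, while avoiding materializing the full log-probability tensor's gradient path.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 32)          # (batch, seq_len, vocab)
index  = torch.randint(0, 32, (2, 5))   # token ids to score
# Reference: full log_softmax, then gather the chosen token's log-prob.
expected = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
# What the function above does: gather first, subtract logsumexp.
manual = (logits.float().gather(-1, index.unsqueeze(-1)).squeeze(-1)
          - torch.logsumexp(logits.float(), dim=-1))
assert torch.allclose(expected, manual, atol=1e-5)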
+ @dataclass
+ class UnslothCPOConfig(CPOConfig):
+     """
+
+     Configuration class for the [`CPOTrainer`].
+
+     Using [`~transformers.HfArgumentParser`] we can turn this class into
+     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+     command line.
+
+     Parameters:
+         learning_rate (`float`, *optional*, defaults to `1e-6`):
+             Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
+             [`~transformers.TrainingArguments`].
+         max_length (`int` or `None`, *optional*, defaults to `1024`):
+             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
+             to use the default data collator.
+         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+             Maximum length of the prompt. This argument is required if you want to use the default data collator.
+         max_completion_length (`int` or `None`, *optional*, defaults to `None`):
+             Maximum length of the completion. This argument is required if you want to use the default data collator
+             and your model is an encoder-decoder.
+         beta (`float`, *optional*, defaults to `0.1`):
+             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
+             reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
+             the [paper](https://huggingface.co/papers/2310.12036).
+         label_smoothing (`float`, *optional*, defaults to `0.0`):
+             Label smoothing factor. This argument is required if you want to use the default data collator.
+         loss_type (`str`, *optional*, defaults to `"sigmoid"`):
+             Type of loss to use. Possible values are:
+
+                 - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
+                 - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
+                 - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
+                 - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
+
+         disable_dropout (`bool`, *optional*, defaults to `True`):
+             Whether to disable dropout in the model.
+         cpo_alpha (`float`, *optional*, defaults to `1.0`):
+             Weight of the BC regularizer in CPO training.
+         simpo_gamma (`float`, *optional*, defaults to `0.5`):
+             Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
+         label_pad_token_id (`int`, *optional*, defaults to `-100`):
+             Label pad token id. This argument is required if you want to use the default data collator.
+         padding_value (`int` or `None`, *optional*, defaults to `None`):
+             Padding value to use. If `None`, the padding value of the tokenizer is used.
+         truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+             Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
+             This argument is required if you want to use the default data collator.
+         generate_during_eval (`bool`, *optional*, defaults to `False`):
+             If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
+         is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
+             When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
+             you need to specify if the model returned by the callable is an encoder-decoder model.
+         model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
+             string.
+         dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
+             Number of processes to use for processing the dataset.
+
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
+         metadata = {'help': 'vLLM SamplingParams'},
+     )
+     unsloth_num_chunks : Optional[int] = field(
+         default = -1,
+         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+     )
+     def __init__(
+         self,
+         output_dir = None,
+         overwrite_output_dir = None,
+         do_train = False,
+         do_eval = False,
+         do_predict = False,
+         eval_strategy = 'no',
+         prediction_loss_only = False,
+         per_device_train_batch_size = 4,
+         per_device_eval_batch_size = 4,
+         per_gpu_train_batch_size = None,
+         per_gpu_eval_batch_size = None,
+         gradient_accumulation_steps = 2,
+         eval_accumulation_steps = 2,
+         eval_delay = 0,
+         torch_empty_cache_steps = 250,
+         learning_rate = 5e-05,
+         weight_decay = 0.01,
+         adam_beta1 = 0.9,
+         adam_beta2 = 0.999,
+         adam_epsilon = 1e-08,
+         max_grad_norm = 1.0,
+         num_train_epochs = 3.0,
+         max_steps = -1,
+         lr_scheduler_type = 'linear',
+         warmup_ratio = 0.1,
+         warmup_steps = 0,
+         log_level = 'passive',
+         log_level_replica = 'warning',
+         log_on_each_node = True,
+         logging_dir = None,
+         logging_strategy = 'steps',
+         logging_first_step = False,
+         logging_steps = 1,
+         logging_nan_inf_filter = False,
+         save_strategy = 'steps',
+         save_steps = 500,
+         save_total_limit = None,
+         save_safetensors = True,
+         save_on_each_node = False,
+         save_only_model = False,
+         restore_callback_states_from_checkpoint = False,
+         no_cuda = False,
+         use_cpu = False,
+         use_mps_device = False,
+         seed = 3407,
+         data_seed = 3407,
+         jit_mode_eval = False,
+         use_ipex = False,
+         bf16 = False,
+         fp16 = False,
+         fp16_opt_level = 'O1',
+         half_precision_backend = 'auto',
+         bf16_full_eval = False,
+         fp16_full_eval = False,
+         tf32 = None,
+         local_rank = -1,
+         ddp_backend = None,
+         tpu_num_cores = None,
+         tpu_metrics_debug = False,
+         debug = '',
+         dataloader_drop_last = False,
+         eval_steps = None,
+         dataloader_num_workers = 0,
+         dataloader_prefetch_factor = None,
+         past_index = -1,
+         run_name = None,
+         disable_tqdm = None,
+         remove_unused_columns = True,
+         label_names = None,
+         load_best_model_at_end = False,
+         metric_for_best_model = None,
+         greater_is_better = None,
+         ignore_data_skip = False,
+         fsdp = '',
+         fsdp_min_num_params = 0,
+         fsdp_config = None,
+         tp_size = 0,
+         fsdp_transformer_layer_cls_to_wrap = None,
+         accelerator_config = None,
+         deepspeed = None,
+         label_smoothing_factor = 0.0,
+         optim = 'adamw_8bit',
+         optim_args = None,
+         adafactor = False,
+         group_by_length = False,
+         length_column_name = 'length',
+         report_to = None,
+         ddp_find_unused_parameters = None,
+         ddp_bucket_cap_mb = None,
+         ddp_broadcast_buffers = None,
+         dataloader_pin_memory = True,
+         dataloader_persistent_workers = False,
+         skip_memory_metrics = True,
+         use_legacy_prediction_loop = False,
+         push_to_hub = False,
+         resume_from_checkpoint = None,
+         hub_model_id = None,
+         hub_strategy = 'every_save',
+         hub_token = None,
+         hub_private_repo = None,
+         hub_always_push = False,
+         gradient_checkpointing = False,
+         gradient_checkpointing_kwargs = None,
+         include_inputs_for_metrics = False,
+         eval_do_concat_batches = True,
+         fp16_backend = 'auto',
+         push_to_hub_model_id = None,
+         push_to_hub_organization = None,
+         push_to_hub_token = None,
+         mp_parameters = '',
+         auto_find_batch_size = False,
+         full_determinism = False,
+         torchdynamo = None,
+         ray_scope = 'last',
+         ddp_timeout = 1800,
+         torch_compile = False,
+         torch_compile_backend = None,
+         torch_compile_mode = None,
+         include_tokens_per_second = False,
+         include_num_input_tokens_seen = False,
+         neftune_noise_alpha = None,
+         optim_target_modules = None,
+         batch_eval_metrics = False,
+         eval_on_start = False,
+         use_liger_kernel = False,
+         eval_use_gather_object = False,
+         average_tokens_across_devices = False,
+         max_length = 1024,
+         max_prompt_length = 512,
+         max_completion_length = None,
+         beta = 0.1,
+         label_smoothing = 0.0,
+         loss_type = 'sigmoid',
+         disable_dropout = True,
+         cpo_alpha = 1.0,
+         simpo_gamma = 0.5,
+         label_pad_token_id = -100,
+         padding_value = None,
+         truncation_mode = 'keep_end',
+         generate_during_eval = False,
+         is_encoder_decoder = None,
+         model_init_kwargs = None,
+         dataset_num_proc = None,
+         vllm_sampling_params = None,
+         unsloth_num_chunks = -1,
+         **kwargs,
+     ):
+         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
+         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+             output_dir = 'unsloth_training_checkpoints'
+             save_strategy = 'no'
+         if dataset_num_proc is None:
+             from multiprocessing import cpu_count
+             dataset_num_proc = cpu_count()
+
+         super().__init__(
+             output_dir = output_dir,
+             overwrite_output_dir = overwrite_output_dir,
+             do_train = do_train,
+             do_eval = do_eval,
+             do_predict = do_predict,
+             eval_strategy = eval_strategy,
+             prediction_loss_only = prediction_loss_only,
+             per_device_train_batch_size = per_device_train_batch_size,
+             per_device_eval_batch_size = per_device_eval_batch_size,
+             per_gpu_train_batch_size = per_gpu_train_batch_size,
+             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+             gradient_accumulation_steps = gradient_accumulation_steps,
+             eval_accumulation_steps = eval_accumulation_steps,
+             eval_delay = eval_delay,
+             torch_empty_cache_steps = torch_empty_cache_steps,
+             learning_rate = learning_rate,
+             weight_decay = weight_decay,
+             adam_beta1 = adam_beta1,
+             adam_beta2 = adam_beta2,
+             adam_epsilon = adam_epsilon,
+             max_grad_norm = max_grad_norm,
+             num_train_epochs = num_train_epochs,
+             max_steps = max_steps,
+             lr_scheduler_type = lr_scheduler_type,
+             warmup_ratio = warmup_ratio,
+             warmup_steps = warmup_steps,
+             log_level = log_level,
+             log_level_replica = log_level_replica,
+             log_on_each_node = log_on_each_node,
+             logging_dir = logging_dir,
+             logging_strategy = logging_strategy,
+             logging_first_step = logging_first_step,
+             logging_steps = logging_steps,
+             logging_nan_inf_filter = logging_nan_inf_filter,
+             save_strategy = save_strategy,
+             save_steps = save_steps,
+             save_total_limit = save_total_limit,
+             save_safetensors = save_safetensors,
+             save_on_each_node = save_on_each_node,
+             save_only_model = save_only_model,
+             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+             no_cuda = no_cuda,
+             use_cpu = use_cpu,
+             use_mps_device = use_mps_device,
+             seed = seed,
+             data_seed = data_seed,
+             jit_mode_eval = jit_mode_eval,
+             use_ipex = use_ipex,
+             bf16 = bf16,
+             fp16 = fp16,
+             fp16_opt_level = fp16_opt_level,
+             half_precision_backend = half_precision_backend,
+             bf16_full_eval = bf16_full_eval,
+             fp16_full_eval = fp16_full_eval,
+             tf32 = tf32,
+             local_rank = local_rank,
+             ddp_backend = ddp_backend,
+             tpu_num_cores = tpu_num_cores,
+             tpu_metrics_debug = tpu_metrics_debug,
+             debug = debug,
+             dataloader_drop_last = dataloader_drop_last,
+             eval_steps = eval_steps,
+             dataloader_num_workers = dataloader_num_workers,
+             dataloader_prefetch_factor = dataloader_prefetch_factor,
+             past_index = past_index,
+             run_name = run_name,
+             disable_tqdm = disable_tqdm,
+             remove_unused_columns = remove_unused_columns,
+             label_names = label_names,
+             load_best_model_at_end = load_best_model_at_end,
+             metric_for_best_model = metric_for_best_model,
+             greater_is_better = greater_is_better,
+             ignore_data_skip = ignore_data_skip,
+             fsdp = fsdp,
+             fsdp_min_num_params = fsdp_min_num_params,
+             fsdp_config = fsdp_config,
+             tp_size = tp_size,
+             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+             accelerator_config = accelerator_config,
+             deepspeed = deepspeed,
+             label_smoothing_factor = label_smoothing_factor,
+             optim = optim,
+             optim_args = optim_args,
+             adafactor = adafactor,
+             group_by_length = group_by_length,
+             length_column_name = length_column_name,
+             report_to = report_to,
+             ddp_find_unused_parameters = ddp_find_unused_parameters,
+             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+             ddp_broadcast_buffers = ddp_broadcast_buffers,
+             dataloader_pin_memory = dataloader_pin_memory,
+             dataloader_persistent_workers = dataloader_persistent_workers,
+             skip_memory_metrics = skip_memory_metrics,
+             use_legacy_prediction_loop = use_legacy_prediction_loop,
+             push_to_hub = push_to_hub,
+             resume_from_checkpoint = resume_from_checkpoint,
+             hub_model_id = hub_model_id,
+             hub_strategy = hub_strategy,
+             hub_token = hub_token,
+             hub_private_repo = hub_private_repo,
+             hub_always_push = hub_always_push,
+             gradient_checkpointing = gradient_checkpointing,
+             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+             include_inputs_for_metrics = include_inputs_for_metrics,
+             eval_do_concat_batches = eval_do_concat_batches,
+             fp16_backend = fp16_backend,
+             push_to_hub_model_id = push_to_hub_model_id,
+             push_to_hub_organization = push_to_hub_organization,
+             push_to_hub_token = push_to_hub_token,
+             mp_parameters = mp_parameters,
+             auto_find_batch_size = auto_find_batch_size,
+             full_determinism = full_determinism,
+             torchdynamo = torchdynamo,
+             ray_scope = ray_scope,
+             ddp_timeout = ddp_timeout,
+             torch_compile = torch_compile,
+             torch_compile_backend = torch_compile_backend,
+             torch_compile_mode = torch_compile_mode,
+             include_tokens_per_second = include_tokens_per_second,
+             include_num_input_tokens_seen = include_num_input_tokens_seen,
+             neftune_noise_alpha = neftune_noise_alpha,
+             optim_target_modules = optim_target_modules,
+             batch_eval_metrics = batch_eval_metrics,
+             eval_on_start = eval_on_start,
+             use_liger_kernel = use_liger_kernel,
+             eval_use_gather_object = eval_use_gather_object,
+             average_tokens_across_devices = average_tokens_across_devices,
+             max_length = max_length,
+             max_prompt_length = max_prompt_length,
+             max_completion_length = max_completion_length,
+             beta = beta,
+             label_smoothing = label_smoothing,
+             loss_type = loss_type,
+             disable_dropout = disable_dropout,
+             cpo_alpha = cpo_alpha,
+             simpo_gamma = simpo_gamma,
+             label_pad_token_id = label_pad_token_id,
+             padding_value = padding_value,
+             truncation_mode = truncation_mode,
+             generate_during_eval = generate_during_eval,
+             is_encoder_decoder = is_encoder_decoder,
+             model_init_kwargs = model_init_kwargs,
+             dataset_num_proc = dataset_num_proc, **kwargs)
+         self.vllm_sampling_params = vllm_sampling_params
+         self.unsloth_num_chunks = unsloth_num_chunks
+ pass
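For context, a minimal way this generated config might be constructed downstream (a sketch, not part of the commit; it assumes the `unsloth_compiled_cache` directory is importable as a package, and the overridden values are illustrative):

from unsloth_compiled_cache.UnslothCPOTrainer import UnslothCPOConfig

# Values mirror the defaults shown in the signature above; only a few are overridden.
config = UnslothCPOConfig(
    output_dir = "outputs",
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    learning_rate = 5e-5,   # must stay within (1e-7, 1] per the checks above
    loss_type = "sigmoid",
    beta = 0.1,
    max_length = 1024,
    max_prompt_length = 512,
)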
+
+ class _UnslothCPOTrainer(Trainer):
+     r""""""
+
+     _tag_names = ["trl", "cpo"]
+
+     def __init__(
+         self,
+         model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+         args: Optional[CPOConfig] = None,
+         data_collator: Optional[DataCollator] = None,
+         train_dataset: Optional[Dataset] = None,
+         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+         processing_class: Optional[
+             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+         ] = None,
+         model_init: Optional[Callable[[], PreTrainedModel]] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+         peft_config: Optional[dict] = None,
+         compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
+     ):
+         if args.model_init_kwargs is None:
+             model_init_kwargs = {}
+         elif not isinstance(model, str):
+             raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
+         else:
+             model_init_kwargs = args.model_init_kwargs
+             torch_dtype = model_init_kwargs.get("torch_dtype")
+             if torch_dtype is not None:
+                 # Convert to `torch.dtype` if a str is passed
+                 if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                     torch_dtype = getattr(torch, torch_dtype)
+                 if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+                     raise ValueError(
+                         f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
+                     )
+                 model_init_kwargs["torch_dtype"] = torch_dtype
+
+         if isinstance(model, str):
+             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+         # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+         # has been called in order to properly call autocast if needed.
+         self._peft_has_been_casted_to_bf16 = False
+
+         if not is_peft_available() and peft_config is not None:
+             raise ValueError(
+                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+             )
+         elif is_peft_available() and peft_config is not None:
+             # if model is a peft model and we have a peft_config, we merge and unload it first
+             if isinstance(model, PeftModel):
+                 model = model.merge_and_unload()
+
+             if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                 _support_gc_kwargs = hasattr(
+                     args, "gradient_checkpointing_kwargs"
+                 ) and "gradient_checkpointing_kwargs" in list(
+                     inspect.signature(prepare_model_for_kbit_training).parameters
+                 )
+
+                 prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                 if _support_gc_kwargs:
+                     prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                 model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+             elif getattr(args, "gradient_checkpointing", False):
+                 # For backward compatibility with older versions of transformers
+                 if hasattr(model, "enable_input_require_grads"):
+                     model.enable_input_require_grads()
+                 else:
+
+                     def make_inputs_require_grad(module, input, output):
+                         output.requires_grad_(True)
+
+                     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+             # get peft model with the given config
+             model = model
+             if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                 peft_module_casting_to_bf16(model)
+                 # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+                 self._peft_has_been_casted_to_bf16 = True
+
+         # For models that use gradient_checkpointing, we need to attach a hook that enables input
+         # to explicitly have `requires_grad=True`, otherwise training will either silently
+         # fail or completely fail.
+         elif getattr(args, "gradient_checkpointing", False):
+             # For backward compatibility with older versions of transformers
+             if hasattr(model, "enable_input_require_grads"):
+                 model.enable_input_require_grads()
+             else:
+
+                 def make_inputs_require_grad(module, input, output):
+                     output.requires_grad_(True)
+
+                 model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+         if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
+             raise ValueError(
+                 "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+                 " Please install `wandb` or `comet-ml` to resolve."
+             )
+
+         if model is not None:
+             self.is_encoder_decoder = model.config.is_encoder_decoder
+         elif args.is_encoder_decoder is None:
+             raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+         else:
+             self.is_encoder_decoder = args.is_encoder_decoder
+
+         if self.is_encoder_decoder:
+             self.decoder_start_token_id = model.config.decoder_start_token_id
+             self.pad_token_id = model.config.pad_token_id
+
+         if processing_class is None:
+             raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
+         if args.max_length is None:
+             warnings.warn(
+                 "`max_length` is not set in the CPOConfig's init;"
+                 " it will default to `512`, but you should set it yourself in the future.",
+                 UserWarning,
+             )
+             max_length = 512
+         else:
+             max_length = args.max_length
+         if args.max_prompt_length is None:
+             warnings.warn(
+                 "`max_prompt_length` is not set in the CPOConfig's init;"
+                 " it will default to `128`, but you should set it yourself in the future.",
+                 UserWarning,
+             )
+             max_prompt_length = 128
+         else:
+             max_prompt_length = args.max_prompt_length
+
+         if args.max_completion_length is None and self.is_encoder_decoder:
+             warnings.warn(
+                 "When using an encoder-decoder architecture, you should set `max_completion_length` in the CPOConfig's init;"
+                 " it will default to `128`, but you should set it yourself in the future.",
+                 UserWarning,
+             )
+             max_completion_length = 128
+         else:
+             max_completion_length = args.max_completion_length
+
+         if data_collator is None:
+             data_collator = DPODataCollatorWithPadding(
+                 pad_token_id=processing_class.pad_token_id,
+                 label_pad_token_id=args.label_pad_token_id,
+                 is_encoder_decoder=self.is_encoder_decoder,
+             )
+
+             if args.remove_unused_columns:
+                 args.remove_unused_columns = False
+                 # warn users
+                 warnings.warn(
+                     "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
+                     " we have set it for you, but you should do it yourself in the future.",
+                     UserWarning,
+                 )
+
+             self.use_dpo_data_collator = True
+         else:
+             self.use_dpo_data_collator = False
+
+         # Disable dropout in the model
+         if args.disable_dropout:
+             disable_dropout_in_model(model)
+
+         self.max_length = max_length
+         self.generate_during_eval = args.generate_during_eval
+         self.label_pad_token_id = args.label_pad_token_id
+         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
+         self.max_prompt_length = max_prompt_length
+         self.truncation_mode = args.truncation_mode
+         self.max_completion_length = max_completion_length
+         self.processing_class = processing_class
+
+         if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
+             warnings.warn(
+                 f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
+                 "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
+                 UserWarning,
+             )
+         if args.loss_type == "kto_pair":
+             raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
+
+         self.beta = args.beta
+         self.label_smoothing = args.label_smoothing
+         self.loss_type = args.loss_type
+         self.cpo_alpha = args.cpo_alpha
+         self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+         self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
+         if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
+             warnings.warn(
+                 "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+                 "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+                 "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+                 "loss.",
+                 UserWarning,
+             )
+
+         if args.loss_type == "simpo":
+             self.simpo_gamma = args.simpo_gamma
+
+         self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+         # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
+         # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
+         # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
+         # of the input, floating-point operations will not be computed." To suppress this warning, we set the
+         # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
+         # that the warning has already been issued.
+         model.warnings_issued["estimate_tokens"] = True
+
+         # Compute that only on the main process for faster data processing.
+         # see: https://github.com/huggingface/trl/pull/1255
+         with PartialState().local_main_process_first():
+             # Extract the prompt if needed, and apply the chat template if needed
+             train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+             train_dataset = train_dataset.map(
+                 maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
+             )
+             if eval_dataset is not None:
+                 eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
+                 eval_dataset = eval_dataset.map(
+                     maybe_apply_chat_template,
+                     fn_kwargs={"tokenizer": processing_class},
+                     num_proc=args.dataset_num_proc,
+                 )
+
+             # tokenize the dataset
+             train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
+             if eval_dataset is not None:
+                 eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
+
+         super().__init__(
+             model=model,
+             args=args,
+             data_collator=data_collator,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             processing_class=processing_class,
+             model_init=model_init,
+             compute_metrics=compute_metrics,
+             callbacks=callbacks,
+             optimizers=optimizers,
+             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+         )
+
+         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+         # self.model_accepts_loss_kwargs to False to enable scaling.
+         self.model_accepts_loss_kwargs = False
+
+         # Add tags for models that have been loaded with the correct transformers version
+         if hasattr(self.model, "add_model_tags"):
+             self.model.add_model_tags(self._tag_names)
+
+         if not hasattr(self, "accelerator"):
+             raise AttributeError(
+                 "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
+             )
+
+     def build_tokenized_answer(self, prompt, answer):
+         """
+         Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
+         It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
+         Reference:
+             https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+         """
+
+         full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
+         prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
+
+         answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
+         answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
+
+         # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
+         full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
+
+         # Prepare input tokens for token by token comparison
+         full_input_ids = np.array(full_tokenized["input_ids"])
+
+         if len(full_input_ids) != len(full_concat_input_ids):
+             raise ValueError("Prompt input ids and answer input ids should have the same length.")
+
+         # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
+         # can be merged together when tokenizing prompt+answer. This could result
+         # in the last token from the prompt being different when tokenized on its own
+         # vs when done as prompt+answer.
+         response_token_ids_start_idx = len(prompt_input_ids)
+
+         # If the tokenized prompt differs from the prompt portion of prompt+answer,
+         # the last token has changed due to merging.
+         if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
+             response_token_ids_start_idx -= 1
+
+         prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
+         prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
+
+         if len(prompt_input_ids) != len(prompt_attention_mask):
+             raise ValueError("Prompt input ids and attention mask should have the same length.")
+
+         answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
+         answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
+
+         return dict(
+             prompt_input_ids=prompt_input_ids,
+             prompt_attention_mask=prompt_attention_mask,
+             input_ids=answer_input_ids,
+             attention_mask=answer_attention_mask,
+         )
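The property this method guards against can be checked directly (a sketch, not part of the commit; the checkpoint name is a placeholder and assumes `transformers` plus local access to a Llama-style tokenizer; whether the two encodings actually differ depends on the specific prompt/answer boundary):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
prompt, answer = "Hello", " world"
separate = (tok(prompt, add_special_tokens=False)["input_ids"]
            + tok(answer, add_special_tokens=False)["input_ids"])
joint = tok(prompt + answer, add_special_tokens=False)["input_ids"]
# `separate` and `joint` can differ at the boundary token, which is why
# build_tokenized_answer re-derives the answer ids as
# enc(a + b)[len(enc(a)):] instead of using enc(b).
print(separate == joint)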
+
+     def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
+         """Tokenize a single row from a CPO specific dataset.
+
+         At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
+         in case the prompt + chosen or prompt + rejected responses is/are too long. First
+         we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
+
+         We also create the labels for the chosen/rejected responses, which are of length equal to
+         the sum of the length of the prompt and the chosen/rejected response, with
+         label_pad_token_id for the prompt tokens.
+         """
+         batch = {}
+         prompt = feature["prompt"]
+         chosen = feature["chosen"]
+         rejected = feature["rejected"]
+
+         if not self.is_encoder_decoder:
+             # Check issues below for more details
+             #  1. https://github.com/huggingface/trl/issues/907
+             #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+             #  3. https://github.com/LianjiaTech/BELLE/issues/337
+
+             if not isinstance(prompt, str):
+                 raise ValueError(f"prompt should be an str but got {type(prompt)}")
+             prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
+             prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
+
+             if not isinstance(chosen, str):
+                 raise ValueError(f"chosen should be an str but got {type(chosen)}")
+             chosen_tokens = self.build_tokenized_answer(prompt, chosen)
+
+             if not isinstance(rejected, str):
+                 raise ValueError(f"rejected should be an str but got {type(rejected)}")
+             rejected_tokens = self.build_tokenized_answer(prompt, rejected)
+
+             # Last prompt token might get merged by tokenizer and
+             # it should not be included for generation if that happens
+             prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
+
+             chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
+             rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
+             prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
+
+             for k, v in prompt_tokens.items():
+                 prompt_tokens[k] = v[:prompt_len_input_ids]
+
+             # Make sure prompts only have one different token at most, and
+             # that their lengths differ by no more than 1
+             num_diff_tokens = sum(
+                 [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
+             )
+             num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
+             if num_diff_tokens > 1 or num_diff_len > 1:
+                 raise ValueError(
+                     "Chosen and rejected prompt_input_ids might only differ on the "
+                     "last token due to tokenizer merge ops."
+                 )
+
+             # add BOS token to head of prompt. Avoid adding if it's already there
+             prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
+                 self.processing_class.bos_token_id,
+                 prompt_len_input_ids,
+                 prompt_tokens,
+                 chosen_prompt_len_input_ids,
+                 chosen_tokens,
+                 rejected_prompt_len_input_ids,
+                 rejected_tokens,
+             )
+
+             # add EOS token to end of answer. Avoid adding if it's already there
+             chosen_tokens, rejected_tokens = add_eos_token_if_needed(
+                 self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
+             )
+
+             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
+
+             # if combined sequence is too long, truncate the prompt
+             for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
+                 if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+                     if self.truncation_mode == "keep_start":
+                         for k in ["prompt_input_ids", "prompt_attention_mask"]:
+                             answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
+                     elif self.truncation_mode == "keep_end":
+                         for k in ["prompt_input_ids", "prompt_attention_mask"]:
+                             answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
+                     else:
+                         raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
+
+             # if that's still too long, truncate the response
+             for answer_tokens in [chosen_tokens, rejected_tokens]:
+                 if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+                     for k in ["input_ids", "attention_mask"]:
+                         answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
+
+             # Create labels
+             chosen_sequence_tokens = {
+                 k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
+             }
+             rejected_sequence_tokens = {
+                 k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
+             }
+             chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
+             chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
+                 self.label_pad_token_id
+             ] * len(chosen_tokens["prompt_input_ids"])
+             rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
+             rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
+                 self.label_pad_token_id
+             ] * len(rejected_tokens["prompt_input_ids"])
+
+             for k, toks in {
+                 "chosen_": chosen_sequence_tokens,
+                 "rejected_": rejected_sequence_tokens,
+                 "": prompt_tokens,
+             }.items():
+                 for type_key, tokens in toks.items():
+                     if type_key == "token_type_ids":
+                         continue
+                     batch[f"{k}{type_key}"] = tokens
+
+         else:
+             chosen_tokens = self.processing_class(
+                 chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
+             )
+             rejected_tokens = self.processing_class(
+                 rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
+             )
+             prompt_tokens = self.processing_class(
+                 prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
+             )
+
+             batch["chosen_labels"] = chosen_tokens["input_ids"]
+             batch["rejected_labels"] = rejected_tokens["input_ids"]
+             batch["prompt_input_ids"] = prompt_tokens["input_ids"]
+             batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
+
+             if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
+                 batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
+                     labels=torch.tensor(batch["rejected_labels"])
+                 )
+                 batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
+                     labels=torch.tensor(batch["chosen_labels"])
+                 )
+
+         return batch
+
+     @staticmethod
+     def concatenated_inputs(
+         batch: dict[str, Union[list, torch.LongTensor]],
+         is_encoder_decoder: bool = False,
+         label_pad_token_id: int = -100,
+         padding_value: int = 0,
+         device: Optional[torch.device] = None,
+     ) -> dict[str, torch.LongTensor]:
+         """Concatenate the chosen and rejected inputs into a single tensor.
+
+         Args:
+             batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
+             is_encoder_decoder: Whether the model is an encoder-decoder model.
+             label_pad_token_id: The label pad token id.
+             padding_value: The padding value to use for the concatenated inputs_ids.
+             device: The device for the concatenated inputs.
+
+         Returns:
+             A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
+         """
+         concatenated_batch = {}
+
+         if is_encoder_decoder:
+             max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
+         else:
+             max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
+
+         for k in batch:
+             if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
+                 if "labels" in k or is_encoder_decoder:
+                     pad_value = label_pad_token_id
+                 elif k.endswith("_input_ids"):
+                     pad_value = padding_value
+                 elif k.endswith("_attention_mask"):
+                     pad_value = 0
+                 concatenated_key = k.replace("chosen", "concatenated")
+                 concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
+         for k in batch:
+             if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
+                 if "labels" in k or is_encoder_decoder:
+                     pad_value = label_pad_token_id
+                 elif k.endswith("_input_ids"):
+                     pad_value = padding_value
+                 elif k.endswith("_attention_mask"):
+                     pad_value = 0
+                 concatenated_key = k.replace("rejected", "concatenated")
+                 concatenated_batch[concatenated_key] = torch.cat(
+                     (
+                         concatenated_batch[concatenated_key],
+                         pad_to_length(batch[k], max_length, pad_value=pad_value),
+                     ),
+                     dim=0,
+                 ).to(device=device)
+
+         if is_encoder_decoder:
+             concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
+             concatenated_batch["concatenated_attention_mask"] = (
+                 batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
+             )
+
+         return concatenated_batch
+
+     def cpo_loss(
+         self,
+         policy_chosen_logps: torch.FloatTensor,
+         policy_rejected_logps: torch.FloatTensor,
+     ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+         """Compute the CPO loss for a batch of policy and reference model log probabilities.
+
+         Args:
+             policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+             policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+
+         Returns:
+             A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
+             The losses tensor contains the CPO loss for each example in the batch.
+             The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
+         """
+         logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
+
+         # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
+         # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
+         # calculates a conservative CPO loss.
+
+         if self.loss_type == "simpo":
+             gamma_logratios = self.simpo_gamma / self.beta
+             logits = logits - gamma_logratios
+             # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
+             losses = (
+                 -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                 - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+             )
+         elif self.loss_type == "sigmoid":
+             # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
+             losses = (
+                 -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                 - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+             )
+         elif self.loss_type == "hinge":
+             losses = torch.relu(1 - self.beta * logits)
+         elif self.loss_type == "ipo":
+             # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
+             losses = (logits - 1 / (2 * self.beta)) ** 2
+         else:
+             raise ValueError(
+                 f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
+             )
+
+         chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
+         rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
+
+         return losses, chosen_rewards, rejected_rewards
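The sigmoid branch above can be checked by hand on toy values (a sketch, not part of the commit; it assumes only `torch`, and the log-probabilities are made up):

import torch
import torch.nn.functional as F

beta, label_smoothing = 0.1, 0.0
policy_chosen_logps   = torch.tensor([-10.0, -12.0])
policy_rejected_logps = torch.tensor([-11.0, -11.5])
logits = policy_chosen_logps - policy_rejected_logps
# Sigmoid CPO loss; with label_smoothing -> 0 this reduces to -log sigmoid(beta * logits).
losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
          - F.logsigmoid(-beta * logits) * label_smoothing)
print(losses)  # smaller where the chosen logps exceed the rejected logps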
+
+     @staticmethod
+     def get_batch_logps(
+         logits: torch.FloatTensor,
+         labels: torch.LongTensor,
+         average_log_prob: bool = False,
+         label_pad_token_id: int = -100,
+         is_encoder_decoder: bool = False,
+     ) -> torch.FloatTensor:
+         """Compute the log probabilities of the given labels under the given logits.
+
+         Args:
+             logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+             labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
+             average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+             label_pad_token_id: The label pad token id.
+             is_encoder_decoder: Whether the model is an encoder-decoder model.
+
+         Returns:
+             A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+         """
+         if logits.shape[:-1] != labels.shape:
+             raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
+
+         if not is_encoder_decoder:
+             labels = labels[:, 1:].clone()
+             logits = logits[:, :-1, :]
+         loss_mask = labels != label_pad_token_id
+
+         # dummy token; we'll ignore the losses on these tokens later
+         labels[labels == label_pad_token_id] = 0
+
+         per_token_logps = selective_log_softmax(logits, labels)
+
+         if average_log_prob:
+             return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+         else:
+             return (per_token_logps * loss_mask).sum(-1)
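A toy check of the masking and shifting in the decoder-only branch (a sketch, not part of the commit; it assumes only `torch` and uses plain log_softmax in place of the compiled helper):

import torch

logits = torch.randn(1, 4, 8)                 # (batch, seq, vocab)
labels = torch.tensor([[-100, 3, 5, -100]])   # -100 marks prompt/pad positions
shift_labels = labels[:, 1:].clone()          # tokens < n predict n
shift_logits = logits[:, :-1, :]
mask = shift_labels != -100
shift_labels[~mask] = 0                       # dummy ids, zeroed out by the mask below
per_tok = torch.log_softmax(shift_logits.float(), -1).gather(
    -1, shift_labels.unsqueeze(-1)).squeeze(-1)
print((per_tok * mask).sum(-1))               # sum of log-probs over real label tokens only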
1031
+
1032
+ def concatenated_forward(
1033
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1034
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1035
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1036
+
1037
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1038
+ """
1039
+ concatenated_batch = self.concatenated_inputs(
1040
+ batch,
1041
+ is_encoder_decoder=self.is_encoder_decoder,
1042
+ label_pad_token_id=self.label_pad_token_id,
1043
+ padding_value=self.padding_value,
1044
+ device=self.accelerator.device,
1045
+ )
1046
+ len_chosen = batch["chosen_labels"].shape[0]
1047
+
1048
+ model_kwargs = (
1049
+ {
1050
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1051
+ }
1052
+ if self.is_encoder_decoder
1053
+ else {}
1054
+ )
1055
+
1056
+ if self.aux_loss_enabled:
1057
+ model_kwargs["output_router_logits"] = True
1058
+
1059
+ outputs = model(
1060
+ concatenated_batch["concatenated_input_ids"],
1061
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1062
+ use_cache=False,
1063
+ **model_kwargs,
1064
+ )
1065
+ all_logits = outputs.logits
1066
+
1067
+ def cross_entropy_loss(logits, labels):
1068
+ if not self.is_encoder_decoder:
1069
+ # Shift so that tokens < n predict n
1070
+ logits = logits[..., :-1, :].contiguous()
1071
+ labels = labels[..., 1:].contiguous()
1072
+ # Flatten the tokens
1073
+ loss_fct = nn.CrossEntropyLoss()
1074
+ logits = logits.view(-1, logits.shape[-1])
1075
+ labels = labels.view(-1)
1076
+ # Enable model parallelism
1077
+ labels = labels.to(logits.device)
1078
+ loss = loss_fct(logits, labels)
1079
+ return loss
1080
+
1081
+ labels = concatenated_batch["concatenated_labels"].clone()
1082
+
1083
+ if self.cpo_alpha == 0:
1084
+ nll_loss = torch.tensor(0.0).to(self.accelerator.device)
1085
+ else:
1086
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1087
+
1088
+ all_logps = self.get_batch_logps(
1089
+ all_logits,
1090
+ concatenated_batch["concatenated_labels"],
1091
+ average_log_prob=self.loss_type in ["ipo", "simpo"],
1092
+ is_encoder_decoder=self.is_encoder_decoder,
1093
+ label_pad_token_id=self.label_pad_token_id,
1094
+ )
1095
+
1096
+ chosen_logps = all_logps[:len_chosen]
1097
+ rejected_logps = all_logps[len_chosen:]
1098
+
1099
+ chosen_logits = all_logits[:len_chosen]
1100
+ rejected_logits = all_logits[len_chosen:]
1101
+
1102
+ if self.aux_loss_enabled:
1103
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
1104
+
1105
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
1106
+
1107
+ def get_batch_loss_metrics(
1108
+ self,
1109
+ model,
1110
+ batch: dict[str, Union[list, torch.LongTensor]],
1111
+ train_eval: Literal["train", "eval"] = "train",
1112
+ ):
1113
+ """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
1114
+ metrics = {}
1115
+
1116
+ forward_output = self.concatenated_forward(model, batch)
1117
+ (
1118
+ policy_chosen_logps,
1119
+ policy_rejected_logps,
1120
+ policy_chosen_logits,
1121
+ policy_rejected_logits,
1122
+ policy_nll_loss,
1123
+ ) = forward_output[:5]
1124
+ if self.aux_loss_enabled:
1125
+ aux_loss = forward_output[5]
1126
+
1127
+ losses, chosen_rewards, rejected_rewards = self.cpo_loss(
1128
+ policy_chosen_logps,
1129
+ policy_rejected_logps,
1130
+ )
1131
+
1132
+ loss = losses.mean() + self.cpo_alpha * policy_nll_loss
1133
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1134
+
1135
+ prefix = "eval_" if train_eval == "eval" else ""
1136
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
1137
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
1138
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
1139
+ metrics[f"{prefix}rewards/margins"] = (
1140
+ self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
1141
+ )
1142
+ metrics[f"{prefix}logps/rejected"] = (
1143
+ self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
1144
+ )
1145
+ metrics[f"{prefix}logps/chosen"] = (
1146
+ self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
1147
+ )
1148
+ metrics[f"{prefix}logits/rejected"] = (
1149
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
1150
+ )
1151
+ metrics[f"{prefix}logits/chosen"] = (
1152
+ self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
1153
+ )
1154
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
1155
+
1156
+ if self.aux_loss_enabled:
1157
+ loss += self.aux_loss_coef * aux_loss
1158
+
1159
+ return loss, metrics
1160
+
1161
+ def compute_loss(
1162
+ self,
1163
+ model: Union[PreTrainedModel, nn.Module],
1164
+ inputs: dict[str, Union[torch.Tensor, Any]],
1165
+ return_outputs=False,
1166
+ num_items_in_batch=None,
1167
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1168
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1169
+
1170
+ with compute_loss_context_manager:
1171
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1172
+
1173
+ # force log the metrics
1174
+ self.store_metrics(metrics, train_eval="train")
1175
+
1176
+ if return_outputs:
1177
+ return (loss, metrics)
1178
+ return loss
1179
+
1180
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1181
+ """Generate samples from the model and reference model for the given batch of inputs."""
1182
+
1183
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1184
+ # the torch cuda amp context manager, as some hidden states are silently cast to full precision.
1185
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1186
+
1187
+ with generate_context_manager:
1188
+ policy_output = model.generate(
1189
+ input_ids=batch["prompt_input_ids"],
1190
+ attention_mask=batch["prompt_attention_mask"],
1191
+ max_length=self.max_length,
1192
+ do_sample=True,
1193
+ pad_token_id=self.processing_class.pad_token_id,
1194
+ )
1195
+
1196
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1197
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1198
+
1199
+ return policy_output_decoded
1200
+
1201
+ def prediction_step(
1202
+ self,
1203
+ model: Union[PreTrainedModel, nn.Module],
1204
+ inputs: dict[str, Union[torch.Tensor, Any]],
1205
+ prediction_loss_only: bool,
1206
+ ignore_keys: Optional[list[str]] = None,
1207
+ ):
1208
+ if ignore_keys is None:
1209
+ if hasattr(model, "config"):
1210
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1211
+ else:
1212
+ ignore_keys = []
1213
+
1214
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1215
+
1216
+ with torch.no_grad(), prediction_context_manager:
1217
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1218
+
1219
+ # force log the metrics
1220
+ self.store_metrics(metrics, train_eval="eval")
1221
+
1222
+ if prediction_loss_only:
1223
+ return (loss.detach(), None, None)
1224
+
1225
+ # logits for the chosen and rejected samples from model
1226
+ logits_dict = {
1227
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1228
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1229
+ }
1230
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1231
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1232
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1233
+
1234
+ return (loss.detach(), logits, labels)
1235
+
1236
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1237
+ for key, value in metrics.items():
1238
+ self._stored_metrics[train_eval][key].append(value)
1239
+
1240
+ def evaluation_loop(
1241
+ self,
1242
+ dataloader: DataLoader,
1243
+ description: str,
1244
+ prediction_loss_only: Optional[bool] = None,
1245
+ ignore_keys: Optional[list[str]] = None,
1246
+ metric_key_prefix: str = "eval",
1247
+ ) -> EvalLoopOutput:
1248
+ """
1249
+ Overriding built-in evaluation loop to store metrics for each batch.
1250
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1251
+
1252
+ Works with or without labels.
1253
+ """
1254
+
1255
+ # Sample and save to game log if requested (for one batch to save time)
1256
+ if self.generate_during_eval:
1257
+ # Generate random indices within the range of the total number of samples
1258
+ num_samples = len(dataloader.dataset)
1259
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1260
+
1261
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1262
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1263
+ random_batch = self.data_collator(random_batch_dataset)
1264
+ random_batch = self._prepare_inputs(random_batch)
1265
+
1266
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1267
+
1268
+ table = pd.DataFrame(
1269
+ columns=["Prompt", "Policy"],
1270
+ data=[
1271
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1272
+ ],
1273
+ )
1274
+ if "wandb" in self.args.report_to:
1275
+ wandb.log({"game_log": wandb.Table(data=table)})
1276
+
1277
+ if "comet_ml" in self.args.report_to:
1278
+ log_table_to_comet_experiment(
1279
+ name="game_log.csv",
1280
+ table=table,
1281
+ )
1282
+
1283
+ # Base evaluation
1284
+ initial_output = super().evaluation_loop(
1285
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1286
+ )
1287
+
1288
+ return initial_output
1289
+
1290
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1291
+ """
1292
+ Log `logs` on the various objects watching training, including stored metrics.
1293
+
1294
+ Args:
1295
+ logs (`dict[str, float]`):
1296
+ The values to log.
1297
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1298
+ Start time of the training.
1299
+ """
1300
+ # logs either has 'loss' or 'eval_loss'
1301
+ train_eval = "train" if "loss" in logs else "eval"
1302
+ # Add averaged stored metrics to logs
1303
+ for key, metrics in self._stored_metrics[train_eval].items():
1304
+ logs[key] = torch.tensor(metrics).mean().item()
1305
+ del self._stored_metrics[train_eval]
1306
+
1307
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1308
+ return super().log(logs, start_time)
1309
+ else: # transformers<=4.46
1310
+ return super().log(logs)
1311
+
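The flush above is a plain mean over whatever `store_metrics` buffered since the last log call, e.g. (values illustrative):

    import torch

    stored = [0.61, 0.59, 0.64]                  # accuracies collected across steps
    logged = torch.tensor(stored).mean().item()  # 0.613..., the single value that lands in `logs`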
1312
+ def _shift_right(self, input_ids):
1313
+ if self.decoder_start_token_id is None:
1314
+ raise ValueError(
1315
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1316
+ )
1317
+
1318
+ # shift inputs to the right
1319
+ if is_torch_fx_proxy(input_ids):
1320
+ # Item assignment is not supported natively for proxies.
1321
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1322
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1323
+ else:
1324
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1325
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1326
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1327
+
1328
+ if self.pad_token_id is None:
1329
+ raise ValueError("model.config.pad_token_id has to be defined.")
1330
+ # replace possible -100 values in labels by `pad_token_id`
1331
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1332
+
1333
+ return shifted_input_ids
1334
+
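For intuition, the right shift turns labels into decoder inputs by prepending the start token and dropping the last position, with `-100` padding mapped back to the pad token. A toy example with made-up ids:

    import torch

    decoder_start_token_id, pad_token_id = 0, 1
    labels = torch.tensor([[5, -100, 7, 8]])
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    # shifted -> tensor([[0, 5, 1, 7]]): start token prepended, -100 replaced by the pad id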
1335
+ def create_model_card(
1336
+ self,
1337
+ model_name: Optional[str] = None,
1338
+ dataset_name: Optional[str] = None,
1339
+ tags: Union[str, list[str], None] = None,
1340
+ ):
1341
+ """
1342
+ Creates a draft of a model card using the information available to the `Trainer`.
1343
+
1344
+ Args:
1345
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1346
+ Name of the model.
1347
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1348
+ Name of the dataset used for training.
1349
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1350
+ Tags to be associated with the model card.
1351
+ """
1352
+ if not self.is_world_process_zero():
1353
+ return
1354
+
1355
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1356
+ base_model = self.model.config._name_or_path
1357
+ else:
1358
+ base_model = None
1359
+
1360
+ tags = tags or []
1361
+ if isinstance(tags, str):
1362
+ tags = [tags]
1363
+
1364
+ if hasattr(self.model.config, "unsloth_version"):
1365
+ tags.append("unsloth")
1366
+
1367
+ citation = textwrap.dedent("""\
1368
+ @inproceedings{xu2024contrastive,
1369
+ title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
1370
+ author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
1371
+ year = 2024,
1372
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
1373
+ publisher = {OpenReview.net},
1374
+ url = {https://openreview.net/forum?id=51iwkioZpn}
1375
+ }""")
1376
+
1377
+ model_card = generate_model_card(
1378
+ base_model=base_model,
1379
+ model_name=model_name,
1380
+ hub_model_id=self.hub_model_id,
1381
+ dataset_name=dataset_name,
1382
+ tags=tags,
1383
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1384
+ comet_url=get_comet_experiment_url(),
1385
+ trainer_name="CPO",
1386
+ trainer_citation=citation,
1387
+ paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
1388
+ paper_id="2401.08417",
1389
+ )
1390
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1391
+ class UnslothCPOTrainer(_UnslothCPOTrainer):
1392
+ """
1393
+
1394
+ Initialize CPOTrainer.
1395
+
1396
+ Args:
1397
+ model (`transformers.PreTrainedModel`):
1398
+ The model to train, preferably an `AutoModelForCausalLM`.
1399
+ args (`CPOConfig`):
1400
+ The CPO config arguments to use for training.
1401
+ data_collator (`transformers.DataCollator`):
1402
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1403
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1404
+ train_dataset (`datasets.Dataset`):
1405
+ The dataset to use for training.
1406
+ eval_dataset (`datasets.Dataset`):
1407
+ The dataset to use for evaluation.
1408
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1409
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1410
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1411
+ reuse the fine-tuned model.
1412
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1413
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1414
+ callbacks (`list[transformers.TrainerCallback]`):
1415
+ The callbacks to use for training.
1416
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1417
+ The optimizer and scheduler to use for training.
1418
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1419
+ The function to use to preprocess the logits before computing the metrics.
1420
+ peft_config (`dict`, defaults to `None`):
1421
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1422
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1423
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1424
+ a dictionary string to metric values.
1425
+
1426
+ """
1427
+ def __init__(
1428
+ self,
1429
+ model = None,
1430
+ args = None,
1431
+ data_collator = None,
1432
+ train_dataset = None,
1433
+ eval_dataset = None,
1434
+ processing_class = None,
1435
+ model_init = None,
1436
+ callbacks = None,
1437
+ preprocess_logits_for_metrics = None,
1438
+ peft_config = None,
1439
+ compute_metrics = None,
1440
+ **kwargs
1441
+ ):
1442
+ if args is None: args = UnslothCPOConfig()
1443
+ use_bf16 = getattr(args, 'bf16', False)
1444
+ use_fp16 = getattr(args, 'fp16', False)
1445
+ force_float32 = False
1446
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1447
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1448
+ force_float32 = True
1449
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1450
+ dtype = getattr(model.config, 'torch_dtype', None)
1451
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1452
+ from unsloth_zoo.utils import _get_dtype
1453
+ dtype = _get_dtype(dtype)
1454
+ float16 = dtype == torch.float16
1455
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1456
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1457
+ if force_float32:
1458
+ args.fp16 = False
1459
+ args.bf16 = False
1460
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1461
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1462
+ args.fp16 = float16
1463
+ args.bf16 = not float16
1464
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1465
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1466
+ args.eval_strategy = 'steps'
1467
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1468
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1469
+ if ga_steps is not None and ga_steps > 1:
1470
+ from transformers import __version__ as transformers_version
1471
+ if Version(transformers_version) <= Version('4.45.2'):
1472
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1473
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1474
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1475
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1476
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1477
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1478
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1479
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1480
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1481
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1482
+ if force_float32:
1483
+ args.bf16_full_eval = False
1484
+ args.fp16_full_eval = False
1485
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1486
+ args.bf16_full_eval = True
1487
+ args.fp16_full_eval = False
1488
+ elif not bf16_full_eval and not fp16_full_eval:
1489
+ args.bf16_full_eval = args.bf16
1490
+ args.fp16_full_eval = args.fp16
1491
+ _output_logits = False
1492
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1493
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1494
+ if _output_logits:
1495
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1496
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1497
+ pass
1498
+ else:
1499
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1500
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1501
+ if args_max_seq_length is None and model_max_seq_length is not None:
1502
+ max_seq_length = model.max_seq_length
1503
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1504
+ if model is not None and hasattr(model, 'for_training'):
1505
+ model.for_training()
1506
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1507
+ if 'processing_class' in locals():
1508
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1509
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1510
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1511
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1512
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1513
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1514
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1515
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1516
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1517
+ else:
1518
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1519
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1520
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1521
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1522
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1523
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1524
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1525
+ else:
1526
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1527
+ other_metrics = []
1528
+
1529
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1530
+ PatchRLStatistics('cpo_trainer', other_metrics)
1531
+
1532
+ super().__init__(
1533
+ model = model,
1534
+ args = args,
1535
+ data_collator = data_collator,
1536
+ train_dataset = train_dataset,
1537
+ eval_dataset = eval_dataset,
1538
+ processing_class = processing_class,
1539
+ model_init = model_init,
1540
+ callbacks = callbacks,
1541
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1542
+ peft_config = peft_config,
1543
+ compute_metrics = compute_metrics,**kwargs)
1544
+ if hasattr(self, 'neftune_hook_handle'):
1545
+ self.neftune_hook_handle.remove()
1546
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1547
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1548
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1549
+ pass
1550
+
1551
+ pass
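A minimal usage sketch for the wrapper above (the model, tokenizer and dataset are assumed to be loaded elsewhere; the preference dataset is assumed to carry "prompt"/"chosen"/"rejected" columns):

    from trl import CPOConfig

    trainer = UnslothCPOTrainer(
        model = model,                  # e.g. a FastLanguageModel prepared for training
        args = CPOConfig(output_dir = "outputs", per_device_train_batch_size = 2),
        train_dataset = train_dataset,
        processing_class = tokenizer,
    )
    trainer.train()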
unsloth_compiled_cache/UnslothDDPOTrainer.py ADDED
@@ -0,0 +1,872 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
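The helper exploits the identity log_softmax(x)_i = x_i - logsumexp(x), gathering only the target-token logits instead of materializing the full log-softmax over the vocabulary. A quick sanity sketch (not part of the cache) of the equivalence:

    import torch

    logits = torch.randn(2, 4, 10)            # (batch, seq, vocab)
    index = torch.randint(0, 10, (2, 4))      # target token ids
    fast = selective_log_softmax(logits, index)
    slow = torch.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(fast, slow, atol = 1e-5)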
+ @dataclass
43
+ class UnslothDDPOConfig(DDPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`DDPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
54
+ Name of this experiment (by default is the file name without the extension name).
55
+ run_name (`str`, *optional*, defaults to `""`):
56
+ Name of this run.
57
+ seed (`int`, *optional*, defaults to `0`):
58
+ Random seed.
59
+ log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
60
+ Log with either 'wandb' or 'tensorboard', check
61
+ https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
62
+ tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
63
+ Keyword arguments for the tracker (e.g. wandb_project).
64
+ accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
65
+ Keyword arguments for the accelerator.
66
+ project_kwargs (`Dict`, *optional*, defaults to `{}`):
67
+ Keyword arguments for the accelerator project config (e.g. `logging_dir`).
68
+ tracker_project_name (`str`, *optional*, defaults to `"trl"`):
69
+ Name of project to use for tracking.
70
+ logdir (`str`, *optional*, defaults to `"logs"`):
71
+ Top-level logging directory for checkpoint saving.
72
+ num_epochs (`int`, *optional*, defaults to `100`):
73
+ Number of epochs to train.
74
+ save_freq (`int`, *optional*, defaults to `1`):
75
+ Number of epochs between saving model checkpoints.
76
+ num_checkpoint_limit (`int`, *optional*, defaults to `5`):
77
+ Number of checkpoints to keep before overwriting old ones.
78
+ mixed_precision (`str`, *optional*, defaults to `"fp16"`):
79
+ Mixed precision training.
80
+ allow_tf32 (`bool`, *optional*, defaults to `True`):
81
+ Allow `tf32` on Ampere GPUs.
82
+ resume_from (`str`, *optional*, defaults to `""`):
83
+ Resume training from a checkpoint.
84
+ sample_num_steps (`int`, *optional*, defaults to `50`):
85
+ Number of sampler inference steps.
86
+ sample_eta (`float`, *optional*, defaults to `1.0`):
87
+ Eta parameter for the DDIM sampler.
88
+ sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
89
+ Classifier-free guidance weight.
90
+ sample_batch_size (`int`, *optional*, defaults to `1`):
91
+ Batch size (per GPU) to use for sampling.
92
+ sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
93
+ Number of batches to sample per epoch.
94
+ train_batch_size (`int`, *optional*, defaults to `1`):
95
+ Batch size (per GPU) to use for training.
96
+ train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
97
+ Use 8bit Adam optimizer from bitsandbytes.
98
+ train_learning_rate (`float`, *optional*, defaults to `3e-4`):
99
+ Learning rate.
100
+ train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
101
+ Adam beta1.
102
+ train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
103
+ Adam beta2.
104
+ train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
105
+ Adam weight decay.
106
+ train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
107
+ Adam epsilon.
108
+ train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
109
+ Number of gradient accumulation steps.
110
+ train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
111
+ Maximum gradient norm for gradient clipping.
112
+ train_num_inner_epochs (`int`, *optional*, defaults to `1`):
113
+ Number of inner epochs per outer epoch.
114
+ train_cfg (`bool`, *optional*, defaults to `True`):
115
+ Whether to use classifier-free guidance during training.
116
+ train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
117
+ Clip advantages to the range.
118
+ train_clip_range (`float`, *optional*, defaults to `1e-4`):
119
+ PPO clip range.
120
+ train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
121
+ Fraction of timesteps to train on.
122
+ per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
123
+ Whether to track statistics for each prompt separately.
124
+ per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
125
+ Number of reward values to store in the buffer for each prompt.
126
+ per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
127
+ Minimum number of reward values to store in the buffer.
128
+ async_reward_computation (`bool`, *optional*, defaults to `False`):
129
+ Whether to compute rewards asynchronously.
130
+ max_workers (`int`, *optional*, defaults to `2`):
131
+ Maximum number of workers to use for async reward computation.
132
+ negative_prompts (`str`, *optional*, defaults to `""`):
133
+ Comma-separated list of prompts to use as negative examples.
134
+ push_to_hub (`bool`, *optional*, defaults to `False`):
135
+ Whether to push the final model checkpoint to the Hub.
136
+
137
+ """
138
+ vllm_sampling_params: Optional[Any] = field(
139
+ default = None,
140
+ metadata = {'help': 'vLLM SamplingParams'},
141
+ )
142
+ unsloth_num_chunks : Optional[int] = field(
143
+ default = -1,
144
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
145
+ )
146
+ def __init__(
147
+ self,
148
+ exp_name = 'app',
149
+ run_name = '',
150
+ seed = 3407,
151
+ log_with = None,
152
+ tracker_project_name = 'trl',
153
+ logdir = 'logs',
154
+ num_epochs = 100,
155
+ save_freq = 1,
156
+ num_checkpoint_limit = 5,
157
+ mixed_precision = 'fp16',
158
+ allow_tf32 = True,
159
+ resume_from = '',
160
+ sample_num_steps = 50,
161
+ sample_eta = 1.0,
162
+ sample_guidance_scale = 5.0,
163
+ sample_batch_size = 1,
164
+ sample_num_batches_per_epoch = 2,
165
+ train_batch_size = 1,
166
+ train_use_8bit_adam = False,
167
+ train_learning_rate = 5e-05,
168
+ train_adam_beta1 = 0.9,
169
+ train_adam_beta2 = 0.999,
170
+ train_adam_weight_decay = 0.01,
171
+ train_adam_epsilon = 1e-08,
172
+ train_gradient_accumulation_steps = 2,
173
+ train_max_grad_norm = 1.0,
174
+ train_num_inner_epochs = 1,
175
+ train_cfg = True,
176
+ train_adv_clip_max = 5.0,
177
+ train_clip_range = 0.0001,
178
+ train_timestep_fraction = 1.0,
179
+ per_prompt_stat_tracking = False,
180
+ per_prompt_stat_tracking_buffer_size = 16,
181
+ per_prompt_stat_tracking_min_count = 16,
182
+ async_reward_computation = False,
183
+ max_workers = 2,
184
+ negative_prompts = '',
185
+ push_to_hub = False,
186
+ vllm_sampling_params = None,
187
+ unsloth_num_chunks = -1,
188
+ **kwargs,
189
+ ):
190
+
191
+ super().__init__(
192
+ exp_name = exp_name,
193
+ run_name = run_name,
194
+ seed = seed,
195
+ log_with = log_with,
196
+ tracker_project_name = tracker_project_name,
197
+ logdir = logdir,
198
+ num_epochs = num_epochs,
199
+ save_freq = save_freq,
200
+ num_checkpoint_limit = num_checkpoint_limit,
201
+ mixed_precision = mixed_precision,
202
+ allow_tf32 = allow_tf32,
203
+ resume_from = resume_from,
204
+ sample_num_steps = sample_num_steps,
205
+ sample_eta = sample_eta,
206
+ sample_guidance_scale = sample_guidance_scale,
207
+ sample_batch_size = sample_batch_size,
208
+ sample_num_batches_per_epoch = sample_num_batches_per_epoch,
209
+ train_batch_size = train_batch_size,
210
+ train_use_8bit_adam = train_use_8bit_adam,
211
+ train_learning_rate = train_learning_rate,
212
+ train_adam_beta1 = train_adam_beta1,
213
+ train_adam_beta2 = train_adam_beta2,
214
+ train_adam_weight_decay = train_adam_weight_decay,
215
+ train_adam_epsilon = train_adam_epsilon,
216
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
217
+ train_max_grad_norm = train_max_grad_norm,
218
+ train_num_inner_epochs = train_num_inner_epochs,
219
+ train_cfg = train_cfg,
220
+ train_adv_clip_max = train_adv_clip_max,
221
+ train_clip_range = train_clip_range,
222
+ train_timestep_fraction = train_timestep_fraction,
223
+ per_prompt_stat_tracking = per_prompt_stat_tracking,
224
+ per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
225
+ per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
226
+ async_reward_computation = async_reward_computation,
227
+ max_workers = max_workers,
228
+ negative_prompts = negative_prompts,
229
+ push_to_hub = push_to_hub,**kwargs)
230
+ self.vllm_sampling_params = vllm_sampling_params
231
+ self.unsloth_num_chunks = unsloth_num_chunks
232
+ pass
233
+
234
+ class _UnslothDDPOTrainer(PyTorchModelHubMixin):
235
+ """"""
236
+
237
+ _tag_names = ["trl", "ddpo"]
238
+
239
+ def __init__(
240
+ self,
241
+ config: DDPOConfig,
242
+ reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
243
+ prompt_function: Callable[[], tuple[str, Any]],
244
+ sd_pipeline: DDPOStableDiffusionPipeline,
245
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
246
+ ):
247
+ if image_samples_hook is None:
248
+ warn("No image_samples_hook provided; no images will be logged")
249
+
250
+ self.prompt_fn = prompt_function
251
+ self.reward_fn = reward_function
252
+ self.config = config
253
+ self.image_samples_callback = image_samples_hook
254
+
255
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
256
+
257
+ if self.config.resume_from:
258
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
259
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
260
+ # get the most recent checkpoint in this directory
261
+ checkpoints = list(
262
+ filter(
263
+ lambda x: "checkpoint_" in x,
264
+ os.listdir(self.config.resume_from),
265
+ )
266
+ )
267
+ if len(checkpoints) == 0:
268
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
269
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
270
+ self.config.resume_from = os.path.join(
271
+ self.config.resume_from,
272
+ f"checkpoint_{checkpoint_numbers[-1]}",
273
+ )
274
+
275
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
276
+
277
+ # number of timesteps within each trajectory to train on
278
+ self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
279
+
280
+ self.accelerator = Accelerator(
281
+ log_with=self.config.log_with,
282
+ mixed_precision=self.config.mixed_precision,
283
+ project_config=accelerator_project_config,
284
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
285
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
286
+ # the total number of optimizer steps to accumulate across.
287
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
288
+ **self.config.accelerator_kwargs,
289
+ )
290
+
291
+ is_okay, message = self._config_check()
292
+ if not is_okay:
293
+ raise ValueError(message)
294
+
295
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
296
+
297
+ if self.accelerator.is_main_process:
298
+ self.accelerator.init_trackers(
299
+ self.config.tracker_project_name,
300
+ config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
301
+ init_kwargs=self.config.tracker_kwargs,
302
+ )
303
+
304
+ logger.info(f"\n{config}")
305
+
306
+ set_seed(self.config.seed, device_specific=True)
307
+
308
+ self.sd_pipeline = sd_pipeline
309
+
310
+ self.sd_pipeline.set_progress_bar_config(
311
+ position=1,
312
+ disable=not self.accelerator.is_local_main_process,
313
+ leave=False,
314
+ desc="Timestep",
315
+ dynamic_ncols=True,
316
+ )
317
+
318
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
319
+ # as these weights are only used for inference; keeping them in full precision is not required.
320
+ if self.accelerator.mixed_precision == "fp16":
321
+ inference_dtype = torch.float16
322
+ elif self.accelerator.mixed_precision == "bf16":
323
+ inference_dtype = torch.bfloat16
324
+ else:
325
+ inference_dtype = torch.float32
326
+
327
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
328
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
329
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
330
+
331
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
332
+
333
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
334
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
335
+
336
+ # Enable TF32 for faster training on Ampere GPUs,
337
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
338
+ if self.config.allow_tf32:
339
+ torch.backends.cuda.matmul.allow_tf32 = True
340
+
341
+ self.optimizer = self._setup_optimizer(
342
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
343
+ )
344
+
345
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
346
+ self.sd_pipeline.tokenizer(
347
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
348
+ return_tensors="pt",
349
+ padding="max_length",
350
+ truncation=True,
351
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
352
+ ).input_ids.to(self.accelerator.device)
353
+ )[0]
354
+
355
+ if config.per_prompt_stat_tracking:
356
+ self.stat_tracker = PerPromptStatTracker(
357
+ config.per_prompt_stat_tracking_buffer_size,
358
+ config.per_prompt_stat_tracking_min_count,
359
+ )
360
+
361
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
362
+ # more memory
363
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
364
+
365
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
366
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
367
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
368
+ else:
369
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
370
+
371
+ if self.config.async_reward_computation:
372
+ self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
373
+
374
+ if config.resume_from:
375
+ logger.info(f"Resuming from {config.resume_from}")
376
+ self.accelerator.load_state(config.resume_from)
377
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
378
+ else:
379
+ self.first_epoch = 0
380
+
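Because gradients are accumulated across both samples and denoising timesteps, the Accelerator's effective accumulation count set above is the product of the two. With the `UnslothDDPOConfig` defaults:

    sample_num_steps = 50
    train_timestep_fraction = 1.0
    train_gradient_accumulation_steps = 2

    num_train_timesteps = int(sample_num_steps * train_timestep_fraction)             # 50
    effective_accumulation = train_gradient_accumulation_steps * num_train_timesteps  # 100 micro-batches per optimizer step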
381
+ def compute_rewards(self, prompt_image_pairs, is_async=False):
382
+ if not is_async:
383
+ rewards = []
384
+ for images, prompts, prompt_metadata in prompt_image_pairs:
385
+ reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
386
+ rewards.append(
387
+ (
388
+ torch.as_tensor(reward, device=self.accelerator.device),
389
+ reward_metadata,
390
+ )
391
+ )
392
+ else:
393
+ rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
394
+ rewards = [
395
+ (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
396
+ for reward, reward_metadata in rewards
397
+ ]
398
+
399
+ return zip(*rewards)
400
+
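`compute_rewards` only assumes that `reward_fn` maps (images, prompts, prompt_metadata) to a (reward, metadata) pair with one scalar reward per image. A hypothetical reward function shaped to that contract:

    import torch

    def brightness_reward(images, prompts, prompt_metadata):
        # toy example: reward brighter images; no extra metadata
        rewards = images.float().mean(dim = (1, 2, 3))
        return rewards, {}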
401
+ def step(self, epoch: int, global_step: int):
402
+ """
403
+ Perform a single step of training.
404
+
405
+ Args:
406
+ epoch (int): The current epoch.
407
+ global_step (int): The current global step.
408
+
409
+ Side Effects:
410
+ - Model weights are updated
411
+ - Logs the statistics to the accelerator trackers.
412
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
413
+
414
+ Returns:
415
+ global_step (int): The updated global step.
416
+
417
+ """
418
+ samples, prompt_image_data = self._generate_samples(
419
+ iterations=self.config.sample_num_batches_per_epoch,
420
+ batch_size=self.config.sample_batch_size,
421
+ )
422
+
423
+ # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
424
+ samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
425
+ rewards, rewards_metadata = self.compute_rewards(
426
+ prompt_image_data, is_async=self.config.async_reward_computation
427
+ )
428
+
429
+ for i, image_data in enumerate(prompt_image_data):
430
+ image_data.extend([rewards[i], rewards_metadata[i]])
431
+
432
+ if self.image_samples_callback is not None:
433
+ self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
434
+
435
+ rewards = torch.cat(rewards)
436
+ rewards = self.accelerator.gather(rewards).cpu().numpy()
437
+
438
+ self.accelerator.log(
439
+ {
440
+ "reward": rewards,
441
+ "epoch": epoch,
442
+ "reward_mean": rewards.mean(),
443
+ "reward_std": rewards.std(),
444
+ },
445
+ step=global_step,
446
+ )
447
+
448
+ if self.config.per_prompt_stat_tracking:
449
+ # gather the prompts across processes
450
+ prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
451
+ prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
452
+ advantages = self.stat_tracker.update(prompts, rewards)
453
+ else:
454
+ advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
455
+
456
+ # ungather advantages; keep the entries corresponding to the samples on this process
457
+ samples["advantages"] = (
458
+ torch.as_tensor(advantages)
459
+ .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
460
+ .to(self.accelerator.device)
461
+ )
462
+
463
+ del samples["prompt_ids"]
464
+
465
+ total_batch_size, num_timesteps = samples["timesteps"].shape
466
+
467
+ for inner_epoch in range(self.config.train_num_inner_epochs):
468
+ # shuffle samples along batch dimension
469
+ perm = torch.randperm(total_batch_size, device=self.accelerator.device)
470
+ samples = {k: v[perm] for k, v in samples.items()}
471
+
472
+ # shuffle along time dimension independently for each sample
473
+ # build an independent random timestep permutation for every sample in the batch
474
+ perms = torch.stack(
475
+ [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
476
+ )
477
+
478
+ for key in ["timesteps", "latents", "next_latents", "log_probs"]:
479
+ samples[key] = samples[key][
480
+ torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
481
+ perms,
482
+ ]
483
+
484
+ original_keys = samples.keys()
485
+ original_values = samples.values()
486
+ # rebatch them as user defined train_batch_size is different from sample_batch_size
487
+ reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
488
+
489
+ # Transpose the list of original values
490
+ transposed_values = zip(*reshaped_values)
491
+ # Create new dictionaries for each row of transposed values
492
+ samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
493
+
494
+ self.sd_pipeline.unet.train()
495
+ global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
496
+ # ensure optimization step at the end of the inner epoch
497
+ if not self.accelerator.sync_gradients:
498
+ raise ValueError(
499
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
500
+ )
501
+
502
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
503
+ self.accelerator.save_state()
504
+
505
+ return global_step
506
+
507
+ def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
508
+ """
509
+ Calculate the loss for a batch of an unpacked sample
510
+
511
+ Args:
512
+ latents (torch.Tensor):
513
+ The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
514
+ timesteps (torch.Tensor):
515
+ The timesteps sampled from the diffusion model, shape: [batch_size]
516
+ next_latents (torch.Tensor):
517
+ The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
518
+ log_probs (torch.Tensor):
519
+ The log probabilities of the latents, shape: [batch_size]
520
+ advantages (torch.Tensor):
521
+ The advantages of the latents, shape: [batch_size]
522
+ embeds (torch.Tensor):
523
+ The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
524
+ Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
525
+
526
+ Returns:
527
+ loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
528
+ (all of these are of shape (1,))
529
+ """
530
+ with self.autocast():
531
+ if self.config.train_cfg:
532
+ noise_pred = self.sd_pipeline.unet(
533
+ torch.cat([latents] * 2),
534
+ torch.cat([timesteps] * 2),
535
+ embeds,
536
+ ).sample
537
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
538
+ noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
539
+ noise_pred_text - noise_pred_uncond
540
+ )
541
+ else:
542
+ noise_pred = self.sd_pipeline.unet(
543
+ latents,
544
+ timesteps,
545
+ embeds,
546
+ ).sample
547
+ # compute the log prob of next_latents given latents under the current model
548
+
549
+ scheduler_step_output = self.sd_pipeline.scheduler_step(
550
+ noise_pred,
551
+ timesteps,
552
+ latents,
553
+ eta=self.config.sample_eta,
554
+ prev_sample=next_latents,
555
+ )
556
+
557
+ log_prob = scheduler_step_output.log_probs
558
+
559
+ advantages = torch.clamp(
560
+ advantages,
561
+ -self.config.train_adv_clip_max,
562
+ self.config.train_adv_clip_max,
563
+ )
564
+
565
+ ratio = torch.exp(log_prob - log_probs)
566
+
567
+ loss = self.loss(advantages, self.config.train_clip_range, ratio)
568
+
569
+ approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
570
+
571
+ clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
572
+
573
+ return loss, approx_kl, clipfrac
574
+
575
+ def loss(
576
+ self,
577
+ advantages: torch.Tensor,
578
+ clip_range: float,
579
+ ratio: torch.Tensor,
580
+ ):
581
+ unclipped_loss = -advantages * ratio
582
+ clipped_loss = -advantages * torch.clamp(
583
+ ratio,
584
+ 1.0 - clip_range,
585
+ 1.0 + clip_range,
586
+ )
587
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
588
+
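This is the standard PPO clipped surrogate: the pessimistic maximum of `-A * ratio` and `-A * clip(ratio, 1 - eps, 1 + eps)`. A small numeric check:

    import torch

    advantages = torch.tensor([1.0, -1.0])
    ratio = torch.tensor([1.5, 1.5])                   # policy drifted far from the sampling policy
    unclipped = -advantages * ratio                    # tensor([-1.5, 1.5])
    clipped = -advantages * ratio.clamp(0.8, 1.2)      # tensor([-1.2, 1.8])
    loss = torch.maximum(unclipped, clipped).mean()    # 0.3: clipping caps the incentive to move further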
589
+ def _setup_optimizer(self, trainable_layers_parameters):
590
+ if self.config.train_use_8bit_adam:
591
+ import bitsandbytes
592
+
593
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
594
+ else:
595
+ optimizer_cls = torch.optim.AdamW
596
+
597
+ return optimizer_cls(
598
+ trainable_layers_parameters,
599
+ lr=self.config.train_learning_rate,
600
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
601
+ weight_decay=self.config.train_adam_weight_decay,
602
+ eps=self.config.train_adam_epsilon,
603
+ )
604
+
605
+ def _save_model_hook(self, models, weights, output_dir):
606
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
607
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
608
+
609
+ def _load_model_hook(self, models, input_dir):
610
+ self.sd_pipeline.load_checkpoint(models, input_dir)
611
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
612
+
613
+ def _generate_samples(self, iterations, batch_size):
614
+ """
615
+ Generate samples from the model
616
+
617
+ Args:
618
+ iterations (int): Number of iterations to generate samples for
619
+ batch_size (int): Batch size to use for sampling
620
+
621
+ Returns:
622
+ samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
623
+ """
624
+ samples = []
625
+ prompt_image_pairs = []
626
+ self.sd_pipeline.unet.eval()
627
+
628
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
629
+
630
+ for _ in range(iterations):
631
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
632
+
633
+ prompt_ids = self.sd_pipeline.tokenizer(
634
+ prompts,
635
+ return_tensors="pt",
636
+ padding="max_length",
637
+ truncation=True,
638
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
639
+ ).input_ids.to(self.accelerator.device)
640
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
641
+
642
+ with self.autocast():
643
+ sd_output = self.sd_pipeline(
644
+ prompt_embeds=prompt_embeds,
645
+ negative_prompt_embeds=sample_neg_prompt_embeds,
646
+ num_inference_steps=self.config.sample_num_steps,
647
+ guidance_scale=self.config.sample_guidance_scale,
648
+ eta=self.config.sample_eta,
649
+ output_type="pt",
650
+ )
651
+
652
+ images = sd_output.images
653
+ latents = sd_output.latents
654
+ log_probs = sd_output.log_probs
655
+
656
+ latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
657
+ log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
658
+ timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
659
+
660
+ samples.append(
661
+ {
662
+ "prompt_ids": prompt_ids,
663
+ "prompt_embeds": prompt_embeds,
664
+ "timesteps": timesteps,
665
+ "latents": latents[:, :-1], # each entry is the latent before timestep t
666
+ "next_latents": latents[:, 1:], # each entry is the latent after timestep t
667
+ "log_probs": log_probs,
668
+ "negative_prompt_embeds": sample_neg_prompt_embeds,
669
+ }
670
+ )
671
+ prompt_image_pairs.append([images, prompts, prompt_metadata])
672
+
673
+ return samples, prompt_image_pairs
674
+
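The slicing at the end pairs each latent with its successor, which is exactly the (state, next_state) transition the PPO-style update consumes. Schematically:

    # latents: (batch, num_steps + 1, C, H, W) after torch.stack
    # latents[:, :-1] -> the latent *before* each denoising step
    # latents[:, 1:]  -> the latent *after* each denoising step
    # so (latents[:, j], next_latents[:, j], log_probs[:, j]) is one transition per timestep j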
675
+ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
676
+ """
677
+ Train on a batch of samples. Main training segment
678
+
679
+ Args:
680
+ inner_epoch (int): The current inner epoch
681
+ epoch (int): The current epoch
682
+ global_step (int): The current global step
683
+ batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
684
+
685
+ Side Effects:
686
+ - Model weights are updated
687
+ - Logs the statistics to the accelerator trackers.
688
+
689
+ Returns:
690
+ global_step (int): The updated global step
691
+ """
692
+ info = defaultdict(list)
693
+ for _i, sample in enumerate(batched_samples):
694
+ if self.config.train_cfg:
695
+ # concat negative prompts to sample prompts to avoid two forward passes
696
+ embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
697
+ else:
698
+ embeds = sample["prompt_embeds"]
699
+
700
+ for j in range(self.num_train_timesteps):
701
+ with self.accelerator.accumulate(self.sd_pipeline.unet):
702
+ loss, approx_kl, clipfrac = self.calculate_loss(
703
+ sample["latents"][:, j],
704
+ sample["timesteps"][:, j],
705
+ sample["next_latents"][:, j],
706
+ sample["log_probs"][:, j],
707
+ sample["advantages"],
708
+ embeds,
709
+ )
710
+ info["approx_kl"].append(approx_kl)
711
+ info["clipfrac"].append(clipfrac)
712
+ info["loss"].append(loss)
713
+
714
+ self.accelerator.backward(loss)
715
+ if self.accelerator.sync_gradients:
716
+ self.accelerator.clip_grad_norm_(
717
+ self.trainable_layers.parameters()
718
+ if not isinstance(self.trainable_layers, list)
719
+ else self.trainable_layers,
720
+ self.config.train_max_grad_norm,
721
+ )
722
+ self.optimizer.step()
723
+ self.optimizer.zero_grad()
724
+
725
+ # Checks if the accelerator has performed an optimization step behind the scenes
726
+ if self.accelerator.sync_gradients:
727
+ # log training-related stuff
728
+ info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
729
+ info = self.accelerator.reduce(info, reduction="mean")
730
+ info.update({"epoch": epoch, "inner_epoch": inner_epoch})
731
+ self.accelerator.log(info, step=global_step)
732
+ global_step += 1
733
+ info = defaultdict(list)
734
+ return global_step
735
+
736
+ def _config_check(self) -> tuple[bool, str]:
737
+ samples_per_epoch = (
738
+ self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
739
+ )
740
+ total_train_batch_size = (
741
+ self.config.train_batch_size
742
+ * self.accelerator.num_processes
743
+ * self.config.train_gradient_accumulation_steps
744
+ )
745
+
746
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
747
+ return (
748
+ False,
749
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
750
+ )
751
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
752
+ return (
753
+ False,
754
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
755
+ )
756
+ if not samples_per_epoch % total_train_batch_size == 0:
757
+ return (
758
+ False,
759
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
760
+ )
761
+ return True, ""
762
+
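Concretely, with the `UnslothDDPOConfig` defaults on a single process, the three checks above pass as follows:

    samples_per_epoch = 1 * 1 * 2        # sample_batch_size * num_processes * sample_num_batches_per_epoch
    total_train_batch_size = 1 * 1 * 2   # train_batch_size * num_processes * train_gradient_accumulation_steps
    assert 1 >= 1                        # sample_batch_size >= train_batch_size
    assert 1 % 1 == 0                    # sample_batch_size divisible by train_batch_size
    assert samples_per_epoch % total_train_batch_size == 0   # 2 % 2 == 0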
763
+ def train(self, epochs: Optional[int] = None):
764
+ """
765
+ Train the model for a given number of epochs
766
+ """
767
+ global_step = 0
768
+ if epochs is None:
769
+ epochs = self.config.num_epochs
770
+ for epoch in range(self.first_epoch, epochs):
771
+ global_step = self.step(epoch, global_step)
772
+
773
+ def _save_pretrained(self, save_directory):
774
+ self.sd_pipeline.save_pretrained(save_directory)
775
+ self.create_model_card()
776
+
777
+ def create_model_card(
778
+ self,
779
+ model_name: Optional[str] = None,
780
+ dataset_name: Optional[str] = None,
781
+ tags: Union[str, list[str], None] = None,
782
+ ):
783
+ """
784
+ Creates a draft of a model card using the information available to the `Trainer`.
785
+
786
+ Args:
787
+ model_name (`str` or `None`, *optional*, defaults to `None`):
788
+ Name of the model.
789
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
790
+ Name of the dataset used for training.
791
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
792
+ Tags to be associated with the model card.
793
+ """
794
+ if not self.is_world_process_zero():
795
+ return
796
+
797
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
798
+ base_model = self.model.config._name_or_path
799
+ else:
800
+ base_model = None
801
+
802
+ tags = tags or []
803
+ if isinstance(tags, str):
804
+ tags = [tags]
805
+
806
+ if hasattr(self.model.config, "unsloth_version"):
807
+ tags.append("unsloth")
808
+
809
+ citation = textwrap.dedent("""\
810
+ @inproceedings{black2024training,
811
+ title = {{Training Diffusion Models with Reinforcement Learning}},
812
+ author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
813
+ year = 2024,
814
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
815
+ publisher = {OpenReview.net},
816
+ url = {https://openreview.net/forum?id=YCWjhGrJFD},
817
+ }""")
818
+
819
+ model_card = generate_model_card(
820
+ base_model=base_model,
821
+ model_name=model_name,
822
+ hub_model_id=self.hub_model_id,
823
+ dataset_name=dataset_name,
824
+ tags=tags,
825
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
826
+ comet_url=get_comet_experiment_url(),
827
+ trainer_name="DDPO",
828
+ trainer_citation=citation,
829
+ paper_title="Training Diffusion Models with Reinforcement Learning",
830
+ paper_id="2305.13301",
831
+ )
832
+
833
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
834
+ class UnslothDDPOTrainer(_UnslothDDPOTrainer):
835
+ """
836
+
837
+ The DDPOTrainer uses Denoising Diffusion Policy Optimization (DDPO) to optimise diffusion models.
838
+ Note that this trainer is heavily inspired by the work at https://github.com/kvablack/ddpo-pytorch
839
+ As of now only Stable Diffusion based pipelines are supported
840
+
841
+ Attributes:
842
+ **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
843
+ details.
844
+ **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
845
+ **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
846
+ **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
847
+ **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
848
+
849
+ """
850
+ def __init__(
851
+ self,
852
+ config,
853
+ reward_function,
854
+ prompt_function,
855
+ sd_pipeline,
856
+ image_samples_hook = None,
857
+ **kwargs
858
+ ):
859
+ if config is None: config = UnslothDDPOConfig()
860
+ other_metrics = []
861
+
862
+ from unsloth_zoo.logging_utils import PatchRLStatistics
863
+ PatchRLStatistics('ddpo_trainer', other_metrics)
864
+
865
+ super().__init__(
866
+ config = config,
867
+ reward_function = reward_function,
868
+ prompt_function = prompt_function,
869
+ sd_pipeline = sd_pipeline,
870
+ image_samples_hook = image_samples_hook,**kwargs)
871
+
872
+ pass
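A minimal wiring sketch for the DDPO wrapper (the pipeline is assumed to be loaded elsewhere; `brightness_reward` is the hypothetical function sketched earlier):

    from trl import DDPOConfig

    trainer = UnslothDDPOTrainer(
        config = DDPOConfig(sample_batch_size = 1, train_batch_size = 1),
        reward_function = brightness_reward,
        prompt_function = lambda: ("a photo of a cat", {}),
        sd_pipeline = sd_pipeline,       # a DDPOStableDiffusionPipeline instance
    )
    trainer.train()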
unsloth_compiled_cache/UnslothDPOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothGKDTrainer.py ADDED
@@ -0,0 +1,857 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
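As a sanity check (illustrative only, not part of the generated file), the function above matches gathering from a full log_softmax; the point of the fused form is that it never materialises the full (batch, seq, vocab) log-probability tensor:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 100)          # (batch, seq, vocab)
index  = torch.randint(0, 100, (2, 5))   # token ids to score

reference = F.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
fast      = selective_log_softmax(logits, index)
assert torch.allclose(reference, fast, atol = 1e-5)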
42
+ @dataclass
43
+ class UnslothGKDConfig(GKDConfig):
44
+ """
45
+
46
+ Configuration class for [`GKDTrainer`].
47
+
48
+ Args:
49
+ temperature (`float`, *optional*, defaults to `0.9`):
50
+ Temperature for sampling. The higher the temperature, the more random the completions.
51
+ lmbda (`float`, *optional*, defaults to `0.5`):
52
+ Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
53
+ student-generated outputs).
54
+ beta (`float`, *optional*, defaults to `0.5`):
55
+ Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
56
+ beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
57
+ max_new_tokens (`int`, *optional*, defaults to `128`):
58
+ Maximum number of tokens to generate per completion.
59
+ teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
60
+ Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
61
+ being trained.
62
+ teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
63
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
64
+ from a string.
65
+ disable_dropout (`bool`, *optional*, defaults to `True`):
66
+ Whether to disable dropout in the model.
67
+ seq_kd (`bool`, *optional*, defaults to `False`):
68
+ Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
69
+ on teacher-generated output).
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ tp_size = 0,
159
+ fsdp_transformer_layer_cls_to_wrap = None,
160
+ accelerator_config = None,
161
+ deepspeed = None,
162
+ label_smoothing_factor = 0.0,
163
+ optim = 'adamw_8bit',
164
+ optim_args = None,
165
+ adafactor = False,
166
+ group_by_length = False,
167
+ length_column_name = 'length',
168
+ report_to = None,
169
+ ddp_find_unused_parameters = None,
170
+ ddp_bucket_cap_mb = None,
171
+ ddp_broadcast_buffers = None,
172
+ dataloader_pin_memory = True,
173
+ dataloader_persistent_workers = False,
174
+ skip_memory_metrics = True,
175
+ use_legacy_prediction_loop = False,
176
+ push_to_hub = False,
177
+ resume_from_checkpoint = None,
178
+ hub_model_id = None,
179
+ hub_strategy = 'every_save',
180
+ hub_token = None,
181
+ hub_private_repo = None,
182
+ hub_always_push = False,
183
+ gradient_checkpointing = False,
184
+ gradient_checkpointing_kwargs = None,
185
+ include_inputs_for_metrics = False,
186
+ eval_do_concat_batches = True,
187
+ fp16_backend = 'auto',
188
+ push_to_hub_model_id = None,
189
+ push_to_hub_organization = None,
190
+ push_to_hub_token = None,
191
+ mp_parameters = '',
192
+ auto_find_batch_size = False,
193
+ full_determinism = False,
194
+ torchdynamo = None,
195
+ ray_scope = 'last',
196
+ ddp_timeout = 1800,
197
+ torch_compile = False,
198
+ torch_compile_backend = None,
199
+ torch_compile_mode = None,
200
+ include_tokens_per_second = False,
201
+ include_num_input_tokens_seen = False,
202
+ neftune_noise_alpha = None,
203
+ optim_target_modules = None,
204
+ batch_eval_metrics = False,
205
+ eval_on_start = False,
206
+ use_liger_kernel = False,
207
+ eval_use_gather_object = False,
208
+ average_tokens_across_devices = False,
209
+ model_init_kwargs = None,
210
+ use_liger = False,
211
+ dataset_text_field = 'text',
212
+ dataset_kwargs = None,
213
+ dataset_num_proc = None,
214
+ max_seq_length = None,
215
+ packing = False,
216
+ eval_packing = None,
217
+ dataset_batch_size = None,
218
+ num_of_sequences = None,
219
+ chars_per_token = None,
220
+ temperature = 0.9,
221
+ lmbda = 0.5,
222
+ beta = 0.5,
223
+ max_new_tokens = 128,
224
+ teacher_model_name_or_path = None,
225
+ teacher_model_init_kwargs = None,
226
+ disable_dropout = True,
227
+ seq_kd = False,
228
+ vllm_sampling_params = None,
229
+ unsloth_num_chunks = -1,
230
+ **kwargs,
231
+ ):
232
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
233
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
234
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
235
+ output_dir = 'unsloth_training_checkpoints'
236
+ save_strategy = 'no'
237
+ if dataset_num_proc is None:
238
+ from multiprocessing import cpu_count
239
+ dataset_num_proc = cpu_count()
240
+
241
+ super().__init__(
242
+ output_dir = output_dir,
243
+ overwrite_output_dir = overwrite_output_dir,
244
+ do_train = do_train,
245
+ do_eval = do_eval,
246
+ do_predict = do_predict,
247
+ eval_strategy = eval_strategy,
248
+ prediction_loss_only = prediction_loss_only,
249
+ per_device_train_batch_size = per_device_train_batch_size,
250
+ per_device_eval_batch_size = per_device_eval_batch_size,
251
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
252
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
253
+ gradient_accumulation_steps = gradient_accumulation_steps,
254
+ eval_accumulation_steps = eval_accumulation_steps,
255
+ eval_delay = eval_delay,
256
+ torch_empty_cache_steps = torch_empty_cache_steps,
257
+ learning_rate = learning_rate,
258
+ weight_decay = weight_decay,
259
+ adam_beta1 = adam_beta1,
260
+ adam_beta2 = adam_beta2,
261
+ adam_epsilon = adam_epsilon,
262
+ max_grad_norm = max_grad_norm,
263
+ num_train_epochs = num_train_epochs,
264
+ max_steps = max_steps,
265
+ lr_scheduler_type = lr_scheduler_type,
266
+ warmup_ratio = warmup_ratio,
267
+ warmup_steps = warmup_steps,
268
+ log_level = log_level,
269
+ log_level_replica = log_level_replica,
270
+ log_on_each_node = log_on_each_node,
271
+ logging_dir = logging_dir,
272
+ logging_strategy = logging_strategy,
273
+ logging_first_step = logging_first_step,
274
+ logging_steps = logging_steps,
275
+ logging_nan_inf_filter = logging_nan_inf_filter,
276
+ save_strategy = save_strategy,
277
+ save_steps = save_steps,
278
+ save_total_limit = save_total_limit,
279
+ save_safetensors = save_safetensors,
280
+ save_on_each_node = save_on_each_node,
281
+ save_only_model = save_only_model,
282
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
283
+ no_cuda = no_cuda,
284
+ use_cpu = use_cpu,
285
+ use_mps_device = use_mps_device,
286
+ seed = seed,
287
+ data_seed = data_seed,
288
+ jit_mode_eval = jit_mode_eval,
289
+ use_ipex = use_ipex,
290
+ bf16 = bf16,
291
+ fp16 = fp16,
292
+ fp16_opt_level = fp16_opt_level,
293
+ half_precision_backend = half_precision_backend,
294
+ bf16_full_eval = bf16_full_eval,
295
+ fp16_full_eval = fp16_full_eval,
296
+ tf32 = tf32,
297
+ local_rank = local_rank,
298
+ ddp_backend = ddp_backend,
299
+ tpu_num_cores = tpu_num_cores,
300
+ tpu_metrics_debug = tpu_metrics_debug,
301
+ debug = debug,
302
+ dataloader_drop_last = dataloader_drop_last,
303
+ eval_steps = eval_steps,
304
+ dataloader_num_workers = dataloader_num_workers,
305
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
306
+ past_index = past_index,
307
+ run_name = run_name,
308
+ disable_tqdm = disable_tqdm,
309
+ remove_unused_columns = remove_unused_columns,
310
+ label_names = label_names,
311
+ load_best_model_at_end = load_best_model_at_end,
312
+ metric_for_best_model = metric_for_best_model,
313
+ greater_is_better = greater_is_better,
314
+ ignore_data_skip = ignore_data_skip,
315
+ fsdp = fsdp,
316
+ fsdp_min_num_params = fsdp_min_num_params,
317
+ fsdp_config = fsdp_config,
318
+ tp_size = tp_size,
319
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
320
+ accelerator_config = accelerator_config,
321
+ deepspeed = deepspeed,
322
+ label_smoothing_factor = label_smoothing_factor,
323
+ optim = optim,
324
+ optim_args = optim_args,
325
+ adafactor = adafactor,
326
+ group_by_length = group_by_length,
327
+ length_column_name = length_column_name,
328
+ report_to = report_to,
329
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
330
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
331
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
332
+ dataloader_pin_memory = dataloader_pin_memory,
333
+ dataloader_persistent_workers = dataloader_persistent_workers,
334
+ skip_memory_metrics = skip_memory_metrics,
335
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
336
+ push_to_hub = push_to_hub,
337
+ resume_from_checkpoint = resume_from_checkpoint,
338
+ hub_model_id = hub_model_id,
339
+ hub_strategy = hub_strategy,
340
+ hub_token = hub_token,
341
+ hub_private_repo = hub_private_repo,
342
+ hub_always_push = hub_always_push,
343
+ gradient_checkpointing = gradient_checkpointing,
344
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
345
+ include_inputs_for_metrics = include_inputs_for_metrics,
346
+ eval_do_concat_batches = eval_do_concat_batches,
347
+ fp16_backend = fp16_backend,
348
+ push_to_hub_model_id = push_to_hub_model_id,
349
+ push_to_hub_organization = push_to_hub_organization,
350
+ push_to_hub_token = push_to_hub_token,
351
+ mp_parameters = mp_parameters,
352
+ auto_find_batch_size = auto_find_batch_size,
353
+ full_determinism = full_determinism,
354
+ torchdynamo = torchdynamo,
355
+ ray_scope = ray_scope,
356
+ ddp_timeout = ddp_timeout,
357
+ torch_compile = torch_compile,
358
+ torch_compile_backend = torch_compile_backend,
359
+ torch_compile_mode = torch_compile_mode,
360
+ include_tokens_per_second = include_tokens_per_second,
361
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
362
+ neftune_noise_alpha = neftune_noise_alpha,
363
+ optim_target_modules = optim_target_modules,
364
+ batch_eval_metrics = batch_eval_metrics,
365
+ eval_on_start = eval_on_start,
366
+ use_liger_kernel = use_liger_kernel,
367
+ eval_use_gather_object = eval_use_gather_object,
368
+ average_tokens_across_devices = average_tokens_across_devices,
369
+ model_init_kwargs = model_init_kwargs,
370
+ use_liger = use_liger,
371
+ dataset_text_field = dataset_text_field,
372
+ dataset_kwargs = dataset_kwargs,
373
+ dataset_num_proc = dataset_num_proc,
374
+ max_seq_length = max_seq_length,
375
+ packing = packing,
376
+ eval_packing = eval_packing,
377
+ dataset_batch_size = dataset_batch_size,
378
+ num_of_sequences = num_of_sequences,
379
+ chars_per_token = chars_per_token,
380
+ temperature = temperature,
381
+ lmbda = lmbda,
382
+ beta = beta,
383
+ max_new_tokens = max_new_tokens,
384
+ teacher_model_name_or_path = teacher_model_name_or_path,
385
+ teacher_model_init_kwargs = teacher_model_init_kwargs,
386
+ disable_dropout = disable_dropout,
387
+ seq_kd = seq_kd, **kwargs)
388
+ self.vllm_sampling_params = vllm_sampling_params
389
+ self.unsloth_num_chunks = unsloth_num_chunks
390
+ pass
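A hypothetical instantiation wiring up only the GKD-specific knobs documented above; every value and the teacher id below are illustrative, not recommendations:

gkd_args = UnslothGKDConfig(
    output_dir = "gkd_out",
    temperature = 0.9,            # sampling temperature for on-policy generations
    lmbda = 0.5,                  # probability a step trains on student-generated data
    beta = 0.5,                   # JSD interpolation coefficient between the two KL directions
    max_new_tokens = 128,
    teacher_model_name_or_path = "org/teacher-model",         # hypothetical hub id
    teacher_model_init_kwargs = {"torch_dtype": "bfloat16"},
)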
391
+
392
+ class _UnslothGKDTrainer(SFTTrainer):
393
+ _tag_names = ["trl", "gkd"]
394
+
395
+ def __init__(
396
+ self,
397
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
398
+ teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
399
+ args: Optional[GKDConfig] = None,
400
+ data_collator: Optional[DataCollator] = None, # type: ignore
401
+ train_dataset: Optional[Dataset] = None,
402
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
403
+ processing_class: Optional[
404
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
405
+ ] = None,
406
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
407
+ callbacks: Optional[list[TrainerCallback]] = None,
408
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
409
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
410
+ peft_config: Optional["PeftConfig"] = None,
411
+ formatting_func: Optional[Callable] = None,
412
+ ):
413
+ # add remove_unused_columns=False to the dataclass args
414
+ args.remove_unused_columns = False
415
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
416
+
417
+ super().__init__(
418
+ model,
419
+ args=args,
420
+ data_collator=data_collator,
421
+ train_dataset=train_dataset,
422
+ eval_dataset=eval_dataset,
423
+ processing_class=processing_class,
424
+ compute_metrics=compute_metrics,
425
+ callbacks=callbacks,
426
+ optimizers=optimizers,
427
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
428
+ peft_config=peft_config,
429
+ formatting_func=formatting_func,
430
+ )
431
+
432
+ if args.teacher_model_init_kwargs is None:
433
+ teacher_model_init_kwargs = {}
434
+ elif not isinstance(teacher_model, str):
435
+ raise ValueError(
436
+ "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
437
+ )
438
+ else:
439
+ teacher_model_init_kwargs = args.teacher_model_init_kwargs
440
+ teacher_model_init_kwargs["torch_dtype"] = (
441
+ teacher_model_init_kwargs["torch_dtype"]
442
+ if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
443
+ else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
444
+ )
445
+
446
+ if isinstance(teacher_model, str):
447
+ if args.use_liger:
448
+ teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)  # requires the liger-kernel package to be installed
449
+ else:
450
+ teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
451
+
452
+ # Disable dropout in the model
453
+ if args.disable_dropout:
454
+ disable_dropout_in_model(self.model)
455
+
456
+ if self.is_deepspeed_enabled:
457
+ self.teacher_model = self._prepare_deepspeed(teacher_model)
458
+ else:
459
+ self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
460
+
461
+ self.lmbda = args.lmbda
462
+ self.beta = args.beta
463
+ self.temperature = args.temperature
464
+ self.seq_kd = args.seq_kd
465
+
466
+ self.generation_config = GenerationConfig(
467
+ max_new_tokens=args.max_new_tokens,
468
+ temperature=args.temperature,
469
+ do_sample=True,
470
+ top_k=0,
471
+ use_cache=False if args.gradient_checkpointing else True,
472
+ pad_token_id=self.processing_class.pad_token_id,
473
+ )
474
+ # Set custom EOS tokens if they are specified by the model's generation
475
+ # config. This is important for models with the Llama 3 chat template,
476
+ # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
477
+ # turns or messages.
478
+ if (
479
+ hasattr(self.model.generation_config, "eos_token_id")
480
+ and self.model.generation_config.eos_token_id is not None
481
+ ):
482
+ self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
483
+
484
+ def _prepare_dataset(self, dataset, *args):
485
+ # SFTTrainer._prepare_dataset() applies the chat template and renames the messages column to text. However, we
486
+ # need to keep the messages column as it is. We use the following workaround to keep the messages column.
487
+ dataset = dataset.add_column("_messages", dataset["messages"])
488
+ dataset = super()._prepare_dataset(dataset, *args)
489
+ dataset = dataset.rename_column("_messages", "messages")
490
+ return dataset
491
+
492
+ @staticmethod
493
+ def generalized_jsd_loss(
494
+ student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
495
+ ):
496
+ """
497
+ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
498
+ of https://huggingface.co/papers/2306.13649 for the definition.
499
+
500
+ Args:
501
+ student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
502
+ teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
503
+ labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
504
+ beta: Interpolation coefficient between 0 and 1 (default: 0.5)
505
+ temperature: Softmax temperature (default: 1.0)
506
+ reduction: Specifies the reduction to apply to the output (default: 'batchmean')
507
+
508
+ Returns:
509
+ loss: Scalar tensor with the generalized JSD loss
510
+ """
511
+
512
+ # Apply temperature scaling
513
+ student_logits = student_logits / temperature
514
+ teacher_logits = teacher_logits / temperature
515
+
516
+ # Compute log probabilities for student and probabilities for teacher
517
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
518
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
519
+
520
+ # Compute the log of the mixture distribution
521
+ # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
522
+ beta = torch.tensor(beta, dtype=student_log_probs.dtype)
523
+ mixture_log_probs = torch.logsumexp(
524
+ torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
525
+ dim=0,
526
+ )
527
+
528
+ # Compute KL divergences using F.kl_div
529
+ # PyTorch's F.kl_div computes KL(target || input), differing from the standard mathematical argument order, so the probability distributions are swapped compared to the notation in the paper.
530
+ kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
531
+ kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
532
+
533
+ # Compute the Generalized Jensen-Shannon Divergence
534
+ jsd = beta * kl_teacher + (1 - beta) * kl_student
535
+
536
+ # Masking
537
+ if labels is not None:
538
+ mask = labels != -100
539
+ jsd = jsd[mask]
540
+
541
+ # Apply reduction
542
+ if reduction == "batchmean":
543
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
544
+ elif reduction == "sum":
545
+ return jsd.sum()
546
+ elif reduction == "mean":
547
+ return jsd.mean()
548
+ else:
549
+ return jsd
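In LaTeX, the per-token quantity computed above (before masking and reduction) is

jsd_t = \beta \, \mathrm{KL}(P_T \,\|\, M) + (1 - \beta) \, \mathrm{KL}(P_S \,\|\, M), \qquad M = \beta P_S + (1 - \beta) P_T,

where P_T and P_S are the temperature-scaled teacher and student token distributions. In the limits, \beta = 0 gives \mathrm{KL}(P_S \,\|\, P_T) and \beta = 1 gives \mathrm{KL}(P_T \,\|\, P_S), so the interpolation coefficient trades off the two KL directions. With reduction="batchmean" and labels given, the masked sum is divided by the number of unmasked tokens rather than the batch size.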
550
+
551
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
552
+ # compute student output
553
+ outputs_student = model(
554
+ input_ids=inputs["input_ids"],
555
+ attention_mask=inputs["attention_mask"],
556
+ )
557
+
558
+ # compute teacher output in eval mode
559
+ self.teacher_model.eval()
560
+ with torch.no_grad():
561
+ outputs_teacher = self.teacher_model(
562
+ input_ids=inputs["input_ids"],
563
+ attention_mask=inputs["attention_mask"],
564
+ )
565
+
566
+ # slice the logits for the generated tokens using the inputs["prompts"] lengths
567
+ prompt_lengths = inputs["prompts"].shape[1]
568
+ shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
569
+ shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
570
+ shifted_labels = inputs["labels"][:, prompt_lengths:]
571
+
572
+ # compute loss
573
+ loss = self.generalized_jsd_loss(
574
+ student_logits=shifted_student_logits,
575
+ teacher_logits=shifted_teacher_logits,
576
+ labels=shifted_labels,
577
+ beta=self.beta,
578
+ )
579
+
580
+ # empty cache
581
+ empty_cache()
582
+
583
+ # Return loss
584
+ return (loss, outputs_student) if return_outputs else loss
585
+
586
+ @staticmethod
587
+ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
588
+ # Generate output with respect to the prompt only
589
+ generated_outputs = model.generate(
590
+ input_ids=inputs["prompts"],
591
+ attention_mask=inputs.get("prompt_attention_mask", None),
592
+ generation_config=generation_config,
593
+ return_dict_in_generate=True,
594
+ )
595
+
596
+ # Get the generated token IDs
597
+ generated_tokens = generated_outputs.sequences
598
+ # Calculate new attention mask
599
+ new_attention_mask = torch.ones_like(generated_tokens)
600
+ new_labels = generated_tokens.clone()
601
+
602
+ # If there's pad_token_id, set attention mask to 0 for padding tokens
603
+ if pad_token_id is not None:
604
+ new_labels[new_labels == pad_token_id] = -100
605
+ new_attention_mask[generated_tokens == pad_token_id] = 0
606
+
607
+ return generated_tokens, new_attention_mask, new_labels
608
+
609
+ def training_step(
610
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
611
+ ) -> torch.Tensor:
612
+ """
613
+ Perform a training step for the Generalized Knowledge Distillation (GKD) model.
614
+
615
+ This method implements the on-policy learning approach described in the GKD paper.
616
+ With probability `self.lmbda`, it generates new responses using the student model,
617
+ which are then used for training instead of the original inputs.
618
+ """
619
+ if self.seq_kd:
620
+ with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
621
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
622
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
623
+ )
624
+ inputs["input_ids"] = new_input_ids
625
+ inputs["attention_mask"] = new_attention_mask
626
+ inputs["labels"] = new_labels
627
+ if random.random() <= self.lmbda:
628
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
629
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
630
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
631
+ )
632
+ inputs["input_ids"] = new_input_ids
633
+ inputs["attention_mask"] = new_attention_mask
634
+ inputs["labels"] = new_labels
635
+
636
+ loss = super().training_step(model, inputs, num_items_in_batch)
637
+ return loss
638
+
639
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
640
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
641
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
642
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
643
+
644
+ if model is not None:
645
+ if hasattr(model, "config"):
646
+ hidden_size = (
647
+ max(model.config.hidden_sizes)
648
+ if getattr(model.config, "hidden_sizes", None)
649
+ else getattr(model.config, "hidden_size", None)
650
+ )
651
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
652
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
653
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
654
+ config_kwargs.update(
655
+ {
656
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
657
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
658
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
659
+ }
660
+ )
661
+
662
+ # If ZeRO-3 is used, we shard both the active and reference model.
663
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
664
+ if config_kwargs["zero_optimization"]["stage"] != 3:
665
+ config_kwargs["zero_optimization"]["stage"] = 0
666
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
667
+ model.eval()
668
+ return model
669
+
670
+ def create_model_card(
671
+ self,
672
+ model_name: Optional[str] = None,
673
+ dataset_name: Optional[str] = None,
674
+ tags: Union[str, list[str], None] = None,
675
+ ):
676
+ """
677
+ Creates a draft of a model card using the information available to the `Trainer`.
678
+
679
+ Args:
680
+ model_name (`str` or `None`, *optional*, defaults to `None`):
681
+ Name of the model.
682
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
683
+ Name of the dataset used for training.
684
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
685
+ Tags to be associated with the model card.
686
+ """
687
+ if not self.is_world_process_zero():
688
+ return
689
+
690
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
691
+ base_model = self.model.config._name_or_path
692
+ else:
693
+ base_model = None
694
+
695
+ tags = tags or []
696
+ if isinstance(tags, str):
697
+ tags = [tags]
698
+
699
+ if hasattr(self.model.config, "unsloth_version"):
700
+ tags.append("unsloth")
701
+
702
+ citation = textwrap.dedent("""\
703
+ @inproceedings{agarwal2024on-policy,
704
+ title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
705
+ author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
706
+ year = 2024,
707
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
708
+ publisher = {OpenReview.net},
709
+ url = {https://openreview.net/forum?id=3zKtaqxLhW},
710
+ }""")
711
+
712
+ model_card = generate_model_card(
713
+ base_model=base_model,
714
+ model_name=model_name,
715
+ hub_model_id=self.hub_model_id,
716
+ dataset_name=dataset_name,
717
+ tags=tags,
718
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
719
+ comet_url=get_comet_experiment_url(),
720
+ trainer_name="GKD",
721
+ trainer_citation=citation,
722
+ paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
723
+ paper_id="2306.13649",
724
+ )
725
+
726
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
727
+ class UnslothGKDTrainer(_UnslothGKDTrainer):
728
+ """
729
+
730
+ """
731
+ def __init__(
732
+ self,
733
+ model = None,
734
+ teacher_model = None,
735
+ args = None,
736
+ data_collator = None,
737
+ train_dataset = None,
738
+ eval_dataset = None,
739
+ processing_class = None,
740
+ compute_metrics = None,
741
+ callbacks = None,
742
+ preprocess_logits_for_metrics = None,
743
+ peft_config = None,
744
+ formatting_func = None,
745
+ **kwargs
746
+ ):
747
+ if args is None: args = UnslothGKDConfig()
748
+ use_bf16 = getattr(args, 'bf16', False)
749
+ use_fp16 = getattr(args, 'fp16', False)
750
+ force_float32 = False
751
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
752
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
753
+ force_float32 = True
754
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
755
+ dtype = getattr(model.config, 'torch_dtype', None)
756
+ if dtype is None: dtype = model.get_input_embeddings().dtype
757
+ from unsloth_zoo.utils import _get_dtype
758
+ dtype = _get_dtype(dtype)
759
+ float16 = dtype == torch.float16
760
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
761
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
762
+ if force_float32:
763
+ args.fp16 = False
764
+ args.bf16 = False
765
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
766
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
767
+ args.fp16 = float16
768
+ args.bf16 = not float16
769
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
770
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
771
+ args.eval_strategy = 'steps'
772
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
773
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
774
+ if ga_steps is not None and ga_steps > 1:
775
+ from transformers import __version__ as transformers_version
776
+ if Version(transformers_version) <= Version('4.45.2'):
777
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
778
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
779
+ if getattr(args, 'eval_strategy', 'no') != 'no':
780
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
781
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
782
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
783
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
784
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
785
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
786
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
787
+ if force_float32:
788
+ args.bf16_full_eval = False
789
+ args.fp16_full_eval = False
790
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
791
+ args.bf16_full_eval = True
792
+ args.fp16_full_eval = False
793
+ elif not bf16_full_eval and not fp16_full_eval:
794
+ args.bf16_full_eval = args.bf16
795
+ args.fp16_full_eval = args.fp16
796
+ _output_logits = False
797
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
798
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
799
+ if _output_logits:
800
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
801
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
802
+ pass
803
+ else:
804
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
805
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
806
+ if args_max_seq_length is None and model_max_seq_length is not None:
807
+ max_seq_length = model.max_seq_length
808
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
809
+ if model is not None and hasattr(model, 'for_training'):
810
+ model.for_training()
811
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
812
+ if 'processing_class' in locals():
813
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
814
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
815
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
816
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
817
+ if not isinstance(data_collator, UnslothVisionDataCollator):
818
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
819
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
820
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
821
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
822
+ else:
823
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
824
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
825
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
826
+ if not isinstance(data_collator, UnslothVisionDataCollator):
827
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
828
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
829
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
830
+ else:
831
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
832
+ other_metrics = []
833
+
834
+ from unsloth_zoo.logging_utils import PatchRLStatistics
835
+ PatchRLStatistics('gkd_trainer', other_metrics)
836
+
837
+ super().__init__(
838
+ model = model,
839
+ teacher_model = teacher_model,
840
+ args = args,
841
+ data_collator = data_collator,
842
+ train_dataset = train_dataset,
843
+ eval_dataset = eval_dataset,
844
+ processing_class = processing_class,
845
+ compute_metrics = compute_metrics,
846
+ callbacks = callbacks,
847
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
848
+ peft_config = peft_config,
849
+ formatting_func = formatting_func, **kwargs)
850
+ if hasattr(self, 'neftune_hook_handle'):
851
+ self.neftune_hook_handle.remove()
852
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
853
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
854
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
855
+ pass
856
+
857
+ pass
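A hypothetical end-to-end sketch of the trainer above; the hub ids and the one-example dataset are stand-ins, not part of this commit:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

student   = AutoModelForCausalLM.from_pretrained("org/student-model")  # hypothetical id
tokenizer = AutoTokenizer.from_pretrained("org/student-model")

train_dataset = Dataset.from_list([
    {"messages": [{"role": "user", "content": "What is 2 + 2?"},
                  {"role": "assistant", "content": "4"}]},
])

trainer = UnslothGKDTrainer(
    model = student,
    teacher_model = "org/teacher-model",  # loaded with from_pretrained when given as a string
    args = gkd_args,                      # the UnslothGKDConfig sketched earlier
    processing_class = tokenizer,
    train_dataset = train_dataset,
)
trainer.train()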
unsloth_compiled_cache/UnslothGRPOTrainer.py ADDED
@@ -0,0 +1,1432 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, wandb, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, wandb, GRPOTrainer, Trainer, gather, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+
43
+ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
44
+ # All Unsloth Zoo code licensed under LGPLv3
45
+ old_logits = old_logits.to(torch.float32)
46
+ new_logits = new_logits.to(torch.float32)
47
+ input_ids = input_ids.unsqueeze(-1)
48
+
49
+ # x_i - logsumexp(x_i)
50
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
51
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
52
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
53
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
54
+
55
+ # Reverse KL
56
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
57
+ # Full correct reverse KL divergence?? Missing term maybe?
58
+ # kl_i = torch.exp(new) * kl_i
59
+
60
+ # Below is forward KL (normal KL)
61
+ # kl_i = torch.exp(old) * (old - new)
62
+
63
+ # Must detach - otherwise gradients are not propagated correctly!
64
+ # exp(x - x) == 1
65
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
66
+ loss_i = -(loss_i - beta * kl_i)
67
+
68
+ mask = mask.to(torch.float32)
69
+ n_mask_per_reward = mask.sum(1)
70
+
71
+ # See https://github.com/huggingface/trl/pull/2881
72
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
73
+ loss = loss_per_reward.mean()
74
+ # loss = (loss_i * mask).sum() / mask.sum()
75
+
76
+ # Get metrics as well which are folded
77
+ with torch.inference_mode():
78
+ completion_length = n_mask_per_reward.mean()
79
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
80
+ mean_kl = mean_kl_per_reward.mean()
81
+ pass
82
+ return loss, completion_length, mean_kl
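In symbols, with per-token log-probabilities \ell^{new}_t and \ell^{old}_t of the sampled token under the current and old policies, advantage A, KL weight \beta, and stop-gradient \mathrm{sg}[\cdot] (i.e. .detach()), the code above computes

\hat{k}_t = e^{\ell^{old}_t - \ell^{new}_t} - (\ell^{old}_t - \ell^{new}_t) - 1 \ \ge\ 0 \qquad \text{(the non-negative "k3" estimator of the reverse KL),}

L_t = -\big( e^{\ell^{new}_t - \mathrm{sg}[\ell^{new}_t]} \, A - \beta \, \hat{k}_t \big),

then averages L_t over the unmasked completion tokens of each sequence and finally over the batch. The exponential factor is identically 1 in value but carries the gradient \nabla_\theta \ell^{new}_t \cdot A, which is exactly why the detach in the code is required.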
83
+
84
+ class UnslothEfficientGRPO(torch.autograd.Function):
85
+ # All Unsloth Zoo code licensed under LGPLv3
86
+ @staticmethod
87
+ def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
88
+ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
89
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
90
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
91
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
92
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
93
+ loss, completion_length, mean_kl = grpo_compute_loss(
94
+ old_logits, new_logits, input_ids, mask, beta, advantages,
95
+ )
96
+ # Scale loss if needed for mixed precision training
97
+ scaled_loss = loss * scaling
98
+ # Must add .loss.detach otherwise autograd uses 2x VRAM
99
+ return scaled_loss, (loss.detach(), completion_length, mean_kl,)
100
+ pass
101
+
102
+ device = _new_hidden_states.device
103
+ grad_inputs = torch.empty_like(_new_hidden_states)
104
+ accumulated_loss = torch.zeros(1, device = device)
105
+ accumulated_completion_length = torch.zeros(1, device = device)
106
+ accumulated_mean_kl = torch.zeros(1, device = device)
107
+
108
+ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
109
+ (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
110
+ compute_loss,
111
+ argnums = (0,),
112
+ has_aux = True,
113
+ )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
114
+ accumulated_loss .add_(unscaled_loss)
115
+ accumulated_completion_length.add_(chunk_completion_length)
116
+ accumulated_mean_kl .add_(chunk_mean_kl)
117
+ return chunk_grad_input
118
+ pass
119
+
120
+ accumulate_chunk = torch.compile(
121
+ accumulate_chunk,
122
+ fullgraph = True,
123
+ options = torch_compile_options,
124
+ )
125
+
126
+ grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
127
+ new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
128
+ old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
129
+ input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
130
+ mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
131
+ advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
132
+
133
+ # Get mixed precision scaling if seen
134
+ scaling = scaler.get_scale() if scaler is not None else 1.0
135
+
136
+ # Force torch.compile to use dynamic shapes for seqlen dim
137
+ mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
138
+
139
+ for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
140
+ zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
141
+
142
+ mark_dynamic(new_hidden_states_j)
143
+ mark_dynamic(old_hidden_states_j)
144
+ mark_dynamic(input_ids_j)
145
+ mark_dynamic(mask_j)
146
+
147
+ grad_inputs_j.copy_(
148
+ accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
149
+ )
150
+ pass
151
+
152
+ grad_inputs .div_(n_chunks)
153
+ accumulated_loss .div_(n_chunks)
154
+ accumulated_completion_length.div_(n_chunks)
155
+ accumulated_mean_kl .div_(n_chunks)
156
+ ctx.save_for_backward(grad_inputs)
157
+
158
+ return (
159
+ accumulated_loss,
160
+ accumulated_completion_length,
161
+ accumulated_mean_kl,
162
+ )
163
+ pass
164
+
165
+ @staticmethod
166
+ def backward(ctx, grad_output, dcompletion_length, dmean_kl):
167
+ (grad_input,) = ctx.saved_tensors
168
+ return (grad_input, None, None, None, None, None, None, None, None,)
169
+ pass
170
+
171
+ def grpo_accumulated_loss(
172
+ trainer,
173
+ input_ids,
174
+ logits_to_keep,
175
+ completion_mask,
176
+ advantages,
177
+ n_chunks = -1,
178
+ ):
179
+ # All Unsloth Zoo code licensed under LGPLv3
180
+ bsz, qlen = input_ids.shape
181
+ # Find closest multiple
182
+ factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
183
+ if n_chunks == -1: n_chunks = bsz
184
+ n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
185
+
186
+ mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
187
+ os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
188
+
189
+ completion_input_ids = input_ids[:, -logits_to_keep:]
190
+ lm_head = trainer.model.get_output_embeddings().weight
191
+
192
+ with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
193
+ with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
194
+ old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
195
+ pass
196
+
197
+ new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
198
+
199
+ loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
200
+ new_hidden_states, old_hidden_states, lm_head,
201
+ completion_input_ids, completion_mask, advantages, trainer.beta,
202
+ trainer.accelerator.scaler,
203
+ n_chunks,
204
+ )
205
+ return loss, completion_length, mean_kl
206
+
207
+ # Old non-efficient code path (unreachable after the return above; kept for reference)
208
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
209
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
210
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
211
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
212
+ loss, completion_length, mean_kl = grpo_compute_loss(
213
+ old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages,
214
+ )
215
+ return loss, completion_length, mean_kl
216
+ pass
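The chunk count used above is snapped to a divisor of the batch size so that torch.chunk yields equal-sized chunks. A standalone restatement of that selection logic, for clarity only:

import numpy as np

def snap_n_chunks(bsz, n_chunks):
    factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
    if n_chunks == -1: n_chunks = bsz
    return factors[min(np.searchsorted(factors, n_chunks), len(factors) - 1)]

assert snap_n_chunks(8, -1) == 8  # default -1 means one chunk per row
assert snap_n_chunks(8, 3)  == 4  # 3 does not divide 8; rounds up to 4
assert snap_n_chunks(6, 3)  == 3  # divisors pass through unchanged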
217
+
218
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
219
+ def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
220
+ # All Unsloth Zoo code licensed under LGPLv3
221
+ old_logits = old_logits.to(torch.float32)
222
+ new_logits = new_logits.to(torch.float32)
223
+ input_ids = input_ids.unsqueeze(-1)
224
+
225
+ # x_i - logsumexp(x_i)
226
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
227
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
228
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
229
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
230
+
231
+ # Reverse KL
232
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
233
+ # Full correct reverse KL divergence?? Missing term maybe?
234
+ # kl_i = torch.exp(new) * kl_i
235
+
236
+ # Below is forward KL (normal KL)
237
+ # kl_i = torch.exp(old) * (old - new)
238
+
239
+ # Must detach - otherwise gradients are not propagated correctly!
240
+ # exp(x - x) == 1
241
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
242
+ loss_i = -(loss_i - beta * kl_i)
243
+
244
+ mask = mask.to(torch.float32)
245
+ n_mask_per_reward = mask.sum(1)
246
+
247
+ # See https://github.com/huggingface/trl/pull/2881
248
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
249
+ loss = loss_per_reward.mean()
250
+ # loss = (loss_i * mask).sum() / mask.sum()
251
+
252
+ # Get metrics as well which are folded
253
+ with torch.inference_mode():
254
+ completion_length = n_mask_per_reward.mean()
255
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
256
+ mean_kl = mean_kl_per_reward.mean()
257
+ pass
258
+ return loss, completion_length, mean_kl
259
+
260
+ def vLLMSamplingParams(**kwargs):
261
+ from vllm import SamplingParams
262
+ sampling_params = SamplingParams(**kwargs)
263
+ sampling_params._set_kwargs = kwargs
264
+ return sampling_params
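A hypothetical way to use this helper: build the vLLM sampling options once and hand them to the GRPO config below through its vllm_sampling_params field (the particular values are arbitrary examples):

sampling  = vLLMSamplingParams(min_p = 0.1, seed = 3407)
grpo_args = UnslothGRPOConfig(
    output_dir = "grpo_out",
    use_vllm = True,
    vllm_sampling_params = sampling,
)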
265
+ @dataclass
266
+ class UnslothGRPOConfig(GRPOConfig):
267
+ """
268
+
269
+ Configuration class for the [`GRPOTrainer`].
270
+
271
+ Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
272
+ [`~transformers.TrainingArguments`] documentation.
273
+
274
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
275
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
276
+ command line.
277
+
278
+ Parameters:
279
+ > Parameters that control the model and reference model
280
+
281
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
282
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
283
+ argument of the [`GRPOTrainer`] is provided as a string.
284
+
285
+ > Parameters that control the data preprocessing
286
+
287
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
288
+ Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
289
+ requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
290
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
291
+ Maximum length of the prompt. If the prompt is longer than this value, it will be truncated from the left.
292
+ num_generations (`int` or `None`, *optional*, defaults to `8`):
293
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
294
+ must be divisible by this value.
295
+ temperature (`float`, *optional*, defaults to `0.9`):
296
+ Temperature for sampling. The higher the temperature, the more random the completions.
297
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`):
298
+ Maximum length of the generated completion.
299
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
300
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
301
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
302
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
303
+ with vLLM generation.
304
+
305
+ > Parameters that control generation acceleration powered by vLLM
306
+
307
+ use_vllm (`bool`, *optional*, defaults to `False`):
308
+ Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
309
+ training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
310
+ vllm_device (`str`, *optional*, defaults to `"auto"`):
311
+ Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
312
+ automatically select the next available GPU after the last one used for training. This assumes that
313
+ training has not already occupied all available GPUs. If only one device is available, the device will be
314
+ shared between both training and vLLM.
315
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
316
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
317
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
318
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
319
+ during initialization.
320
+ vllm_dtype (`str`, *optional*, defaults to `"auto"`):
321
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
322
+ based on the model configuration. Find the supported values in the vLLM documentation.
323
+ vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
324
+ If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
325
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
326
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
327
+
328
+ > Parameters that control the training
329
+
330
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
331
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
332
+ [`~transformers.TrainingArguments`].
333
+ beta (`float`, *optional*, defaults to `0.04`):
334
+ KL coefficient.
335
+ reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
336
+ Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
337
+ weighted equally with weight `1.0`.
338
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
339
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
340
+ the `ref_model_mixup_alpha` parameter. This synchronization originates from the
341
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
342
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
343
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
344
+ between the current policy and the previous reference policy during updates. The reference policy is
345
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
346
+ must set `sync_ref_model=True`.
347
+ ref_model_sync_steps (`int`, *optional*, defaults to `64`):
348
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
349
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
350
+ set `sync_ref_model=True`.
351
+
352
+ > Parameters that control the logging
353
+
354
+ log_completions (`bool`, *optional*, defaults to `False`):
355
+ Whether to log the completions during training.
356
+
357
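+ Example (an illustrative sketch, not part of the original docstring; the parameter
+ names are documented above, the values are hypothetical):
+
+ ```python
+ from trl import GRPOConfig
+
+ config = GRPOConfig(
+     output_dir="grpo-demo",
+     temperature=0.9,
+     max_completion_length=256,
+     use_vllm=True,                      # needs `pip install vllm` and a spare GPU
+     vllm_gpu_memory_utilization=0.9,
+     learning_rate=1e-6,
+     beta=0.04,                          # KL coefficient
+     sync_ref_model=True,                # TR-DPO-style reference-model syncing
+     ref_model_mixup_alpha=0.9,          # π_ref = α * π_θ + (1 - α) * π_ref_prev
+     ref_model_sync_steps=64,
+ )
+ ```
+
+ The TR-DPO update itself is, schematically (a sketch of the equation above, not
+ this file's implementation):
+
+ ```python
+ import torch
+
+ @torch.no_grad()
+ def soft_update(ref_model, policy, alpha):
+     # pi_ref <- alpha * pi_theta + (1 - alpha) * pi_ref_prev, parameter-wise
+     for ref_p, pol_p in zip(ref_model.parameters(), policy.parameters()):
+         ref_p.mul_(1.0 - alpha).add_(pol_p, alpha=alpha)
+ ```
+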
+ """
358
+ vllm_sampling_params: Optional[Any] = field(
359
+ default = None,
360
+ metadata = {'help': 'vLLM SamplingParams'},
361
+ )
362
+ unsloth_num_chunks : Optional[int] = field(
363
+ default = -1,
364
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
365
+ )
366
+ def __init__(
367
+ self,
368
+ output_dir = None,
369
+ overwrite_output_dir = None,
370
+ do_train = False,
371
+ do_eval = False,
372
+ do_predict = False,
373
+ eval_strategy = 'no',
374
+ prediction_loss_only = False,
375
+ per_device_train_batch_size = 4,
376
+ per_device_eval_batch_size = 4,
377
+ per_gpu_train_batch_size = None,
378
+ per_gpu_eval_batch_size = None,
379
+ gradient_accumulation_steps = 2,
380
+ eval_accumulation_steps = 2,
381
+ eval_delay = 0,
382
+ torch_empty_cache_steps = 250,
383
+ learning_rate = 5e-05,
384
+ weight_decay = 0.01,
385
+ adam_beta1 = 0.9,
386
+ adam_beta2 = 0.999,
387
+ adam_epsilon = 1e-08,
388
+ max_grad_norm = 1.0,
389
+ num_train_epochs = 3.0,
390
+ max_steps = -1,
391
+ lr_scheduler_type = 'linear',
392
+ warmup_ratio = 0.1,
393
+ warmup_steps = 0,
394
+ log_level = 'passive',
395
+ log_level_replica = 'warning',
396
+ log_on_each_node = True,
397
+ logging_dir = None,
398
+ logging_strategy = 'steps',
399
+ logging_first_step = False,
400
+ logging_steps = 1,
401
+ logging_nan_inf_filter = False,
402
+ save_strategy = 'steps',
403
+ save_steps = 500,
404
+ save_total_limit = None,
405
+ save_safetensors = True,
406
+ save_on_each_node = False,
407
+ save_only_model = False,
408
+ restore_callback_states_from_checkpoint = False,
409
+ no_cuda = False,
410
+ use_cpu = False,
411
+ use_mps_device = False,
412
+ seed = 3407,
413
+ data_seed = 3407,
414
+ jit_mode_eval = False,
415
+ use_ipex = False,
416
+ bf16 = False,
417
+ fp16 = False,
418
+ fp16_opt_level = 'O1',
419
+ half_precision_backend = 'auto',
420
+ bf16_full_eval = False,
421
+ fp16_full_eval = False,
422
+ tf32 = None,
423
+ local_rank = -1,
424
+ ddp_backend = None,
425
+ tpu_num_cores = None,
426
+ tpu_metrics_debug = False,
427
+ debug = '',
428
+ dataloader_drop_last = False,
429
+ eval_steps = None,
430
+ dataloader_num_workers = 0,
431
+ dataloader_prefetch_factor = None,
432
+ past_index = -1,
433
+ run_name = None,
434
+ disable_tqdm = None,
435
+ remove_unused_columns = False,
436
+ label_names = None,
437
+ load_best_model_at_end = False,
438
+ metric_for_best_model = None,
439
+ greater_is_better = None,
440
+ ignore_data_skip = False,
441
+ fsdp = '',
442
+ fsdp_min_num_params = 0,
443
+ fsdp_config = None,
444
+ tp_size = 0,
445
+ fsdp_transformer_layer_cls_to_wrap = None,
446
+ accelerator_config = None,
447
+ deepspeed = None,
448
+ label_smoothing_factor = 0.0,
449
+ optim = 'adamw_8bit',
450
+ optim_args = None,
451
+ adafactor = False,
452
+ group_by_length = False,
453
+ length_column_name = 'length',
454
+ report_to = None,
455
+ ddp_find_unused_parameters = None,
456
+ ddp_bucket_cap_mb = None,
457
+ ddp_broadcast_buffers = None,
458
+ dataloader_pin_memory = True,
459
+ dataloader_persistent_workers = False,
460
+ skip_memory_metrics = True,
461
+ use_legacy_prediction_loop = False,
462
+ push_to_hub = False,
463
+ resume_from_checkpoint = None,
464
+ hub_model_id = None,
465
+ hub_strategy = 'every_save',
466
+ hub_token = None,
467
+ hub_private_repo = None,
468
+ hub_always_push = False,
469
+ gradient_checkpointing = False,
470
+ gradient_checkpointing_kwargs = None,
471
+ include_inputs_for_metrics = False,
472
+ eval_do_concat_batches = True,
473
+ fp16_backend = 'auto',
474
+ push_to_hub_model_id = None,
475
+ push_to_hub_organization = None,
476
+ push_to_hub_token = None,
477
+ mp_parameters = '',
478
+ auto_find_batch_size = False,
479
+ full_determinism = False,
480
+ torchdynamo = None,
481
+ ray_scope = 'last',
482
+ ddp_timeout = 1800,
483
+ torch_compile = False,
484
+ torch_compile_backend = None,
485
+ torch_compile_mode = None,
486
+ include_tokens_per_second = False,
487
+ include_num_input_tokens_seen = False,
488
+ neftune_noise_alpha = None,
489
+ optim_target_modules = None,
490
+ batch_eval_metrics = False,
491
+ eval_on_start = False,
492
+ use_liger_kernel = False,
493
+ eval_use_gather_object = False,
494
+ average_tokens_across_devices = False,
495
+ model_init_kwargs = None,
496
+ max_prompt_length = 512,
497
+ num_generations = 8,
498
+ temperature = 0.9,
499
+ max_completion_length = 256,
500
+ ds3_gather_for_generation = True,
501
+ use_vllm = False,
502
+ vllm_device = 'auto',
503
+ vllm_gpu_memory_utilization = 0.9,
504
+ vllm_dtype = 'auto',
505
+ vllm_max_model_len = None,
506
+ beta = 0.04,
507
+ reward_weights = None,
508
+ sync_ref_model = False,
509
+ ref_model_mixup_alpha = 0.9,
510
+ ref_model_sync_steps = 64,
511
+ log_completions = False,
512
+ vllm_sampling_params = None,
513
+ unsloth_num_chunks = -1,
514
+ **kwargs,
515
+ ):
516
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
517
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
518
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
519
+ output_dir = 'unsloth_training_checkpoints'
520
+ save_strategy = 'no'
521
+ div = per_device_train_batch_size // num_generations
522
+ if div * num_generations != per_device_train_batch_size:
523
+ print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
524
+ per_device_train_batch_size = num_generations
525
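+ # Worked example (illustrative): per_device_train_batch_size = 6 with
+ # num_generations = 8 gives div = 6 // 8 = 0 and 0 * 8 != 6, so the batch size is
+ # reset to 8: each device must hold whole groups of `num_generations` completions.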
+
+ super().__init__(
+ output_dir = output_dir,
+ overwrite_output_dir = overwrite_output_dir,
+ do_train = do_train,
+ do_eval = do_eval,
+ do_predict = do_predict,
+ eval_strategy = eval_strategy,
+ prediction_loss_only = prediction_loss_only,
+ per_device_train_batch_size = per_device_train_batch_size,
+ per_device_eval_batch_size = per_device_eval_batch_size,
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+ gradient_accumulation_steps = gradient_accumulation_steps,
+ eval_accumulation_steps = eval_accumulation_steps,
+ eval_delay = eval_delay,
+ torch_empty_cache_steps = torch_empty_cache_steps,
+ learning_rate = learning_rate,
+ weight_decay = weight_decay,
+ adam_beta1 = adam_beta1,
+ adam_beta2 = adam_beta2,
+ adam_epsilon = adam_epsilon,
+ max_grad_norm = max_grad_norm,
+ num_train_epochs = num_train_epochs,
+ max_steps = max_steps,
+ lr_scheduler_type = lr_scheduler_type,
+ warmup_ratio = warmup_ratio,
+ warmup_steps = warmup_steps,
+ log_level = log_level,
+ log_level_replica = log_level_replica,
+ log_on_each_node = log_on_each_node,
+ logging_dir = logging_dir,
+ logging_strategy = logging_strategy,
+ logging_first_step = logging_first_step,
+ logging_steps = logging_steps,
+ logging_nan_inf_filter = logging_nan_inf_filter,
+ save_strategy = save_strategy,
+ save_steps = save_steps,
+ save_total_limit = save_total_limit,
+ save_safetensors = save_safetensors,
+ save_on_each_node = save_on_each_node,
+ save_only_model = save_only_model,
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+ no_cuda = no_cuda,
+ use_cpu = use_cpu,
+ use_mps_device = use_mps_device,
+ seed = seed,
+ data_seed = data_seed,
+ jit_mode_eval = jit_mode_eval,
+ use_ipex = use_ipex,
+ bf16 = bf16,
+ fp16 = fp16,
+ fp16_opt_level = fp16_opt_level,
+ half_precision_backend = half_precision_backend,
+ bf16_full_eval = bf16_full_eval,
+ fp16_full_eval = fp16_full_eval,
+ tf32 = tf32,
+ local_rank = local_rank,
+ ddp_backend = ddp_backend,
+ tpu_num_cores = tpu_num_cores,
+ tpu_metrics_debug = tpu_metrics_debug,
+ debug = debug,
+ dataloader_drop_last = dataloader_drop_last,
+ eval_steps = eval_steps,
+ dataloader_num_workers = dataloader_num_workers,
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
+ past_index = past_index,
+ run_name = run_name,
+ disable_tqdm = disable_tqdm,
+ remove_unused_columns = remove_unused_columns,
+ label_names = label_names,
+ load_best_model_at_end = load_best_model_at_end,
+ metric_for_best_model = metric_for_best_model,
+ greater_is_better = greater_is_better,
+ ignore_data_skip = ignore_data_skip,
+ fsdp = fsdp,
+ fsdp_min_num_params = fsdp_min_num_params,
+ fsdp_config = fsdp_config,
+ tp_size = tp_size,
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+ accelerator_config = accelerator_config,
+ deepspeed = deepspeed,
+ label_smoothing_factor = label_smoothing_factor,
+ optim = optim,
+ optim_args = optim_args,
+ adafactor = adafactor,
+ group_by_length = group_by_length,
+ length_column_name = length_column_name,
+ report_to = report_to,
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
+ dataloader_pin_memory = dataloader_pin_memory,
+ dataloader_persistent_workers = dataloader_persistent_workers,
+ skip_memory_metrics = skip_memory_metrics,
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
+ push_to_hub = push_to_hub,
+ resume_from_checkpoint = resume_from_checkpoint,
+ hub_model_id = hub_model_id,
+ hub_strategy = hub_strategy,
+ hub_token = hub_token,
+ hub_private_repo = hub_private_repo,
+ hub_always_push = hub_always_push,
+ gradient_checkpointing = gradient_checkpointing,
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+ include_inputs_for_metrics = include_inputs_for_metrics,
+ eval_do_concat_batches = eval_do_concat_batches,
+ fp16_backend = fp16_backend,
+ push_to_hub_model_id = push_to_hub_model_id,
+ push_to_hub_organization = push_to_hub_organization,
+ push_to_hub_token = push_to_hub_token,
+ mp_parameters = mp_parameters,
+ auto_find_batch_size = auto_find_batch_size,
+ full_determinism = full_determinism,
+ torchdynamo = torchdynamo,
+ ray_scope = ray_scope,
+ ddp_timeout = ddp_timeout,
+ torch_compile = torch_compile,
+ torch_compile_backend = torch_compile_backend,
+ torch_compile_mode = torch_compile_mode,
+ include_tokens_per_second = include_tokens_per_second,
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
+ neftune_noise_alpha = neftune_noise_alpha,
+ optim_target_modules = optim_target_modules,
+ batch_eval_metrics = batch_eval_metrics,
+ eval_on_start = eval_on_start,
+ use_liger_kernel = use_liger_kernel,
+ eval_use_gather_object = eval_use_gather_object,
+ average_tokens_across_devices = average_tokens_across_devices,
+ model_init_kwargs = model_init_kwargs,
+ max_prompt_length = max_prompt_length,
+ num_generations = num_generations,
+ temperature = temperature,
+ max_completion_length = max_completion_length,
+ ds3_gather_for_generation = ds3_gather_for_generation,
+ use_vllm = use_vllm,
+ vllm_device = vllm_device,
+ vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
+ vllm_dtype = vllm_dtype,
+ vllm_max_model_len = vllm_max_model_len,
+ beta = beta,
+ reward_weights = reward_weights,
+ sync_ref_model = sync_ref_model,
+ ref_model_mixup_alpha = ref_model_mixup_alpha,
+ ref_model_sync_steps = ref_model_sync_steps,
+ log_completions = log_completions,**kwargs)
+ self.vllm_sampling_params = vllm_sampling_params
+ self.unsloth_num_chunks = unsloth_num_chunks
+ pass
+
+ class _UnslothGRPOTrainer(Trainer):
+ """"""
+
+ _tag_names = ["trl", "grpo"]
+
+ def __init__(
+ self,
+ model: Union[str, PreTrainedModel],
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
+ args: GRPOConfig = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+ peft_config: Optional["PeftConfig"] = None,
+ ):
+
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
+ # Args
+ if args is None:
+ model_name = model if isinstance(model, str) else model.config._name_or_path
+ model_name = model_name.split("/")[-1]
+ args = GRPOConfig(f"{model_name}-GRPO")
+
+ # Models
+ # Trained model
+ model_init_kwargs = args.model_init_kwargs or {}
+ if isinstance(model, str):
+ model_id = model
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+ pass # torch_dtype is already a torch.dtype or "auto" or None
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
+ torch_dtype = getattr(torch, torch_dtype)
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ else:
+ raise ValueError(
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+ )
+ # Disable caching if gradient checkpointing is enabled (not supported)
+ model_init_kwargs["use_cache"] = (
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+ )
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+ else:
+ model_id = model.config._name_or_path
+ if args.model_init_kwargs is not None:
+ raise ValueError(
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+ "This argument can only be used when the `model` argument is a string."
+ )
+
+ if False:
+ model = model
+
+ # Reference model
+ if is_deepspeed_zero3_enabled():
+ self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+ elif not is_peft_model(model):
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
+ self.ref_model = create_reference_model(model)
+ else:
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
+ # to revert to the initial model.
+ self.ref_model = None
+
+ # Processing class
+ if processing_class is None:
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
+
+ # Reward functions
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+ for i, reward_func in enumerate(reward_funcs):
+ if isinstance(reward_func, str):
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+ reward_func, num_labels=1, **model_init_kwargs
+ )
+ self.reward_funcs = reward_funcs
+
+ # Reward weights
+ if args.reward_weights is not None:
+ if len(args.reward_weights) != len(reward_funcs):
+ raise ValueError(
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
+ f"functions ({len(reward_funcs)})"
+ )
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
+ else:
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
+
+ # Reward processing class
+ if reward_processing_classes is None:
+ reward_processing_classes = [None] * len(reward_funcs)
+ elif not isinstance(reward_processing_classes, list):
+ reward_processing_classes = [reward_processing_classes]
+ else:
+ if len(reward_processing_classes) != len(reward_funcs):
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
+
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
+ if isinstance(reward_func, PreTrainedModel):
+ if reward_processing_class is None:
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+ if reward_processing_class.pad_token_id is None:
+ reward_processing_class.pad_token = reward_processing_class.eos_token
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+ reward_processing_classes[i] = reward_processing_class
+ self.reward_processing_classes = reward_processing_classes
+
+ # Data collator
+ def data_collator(features): # No data collation is needed in GRPO
+ return features
+
+ # Training arguments
+ self.max_prompt_length = args.max_prompt_length
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
+ self.num_generations = args.num_generations # = G in the GRPO paper
+ self.use_vllm = args.use_vllm
+
+ self.beta = args.beta
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+ # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+ # This acts as a flag to indicate that the warning has already been issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ # Initialize the metrics
+ self._metrics = defaultdict(list)
+ self.log_completions = args.log_completions
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ )
+
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
+ num_processes = self.accelerator.num_processes
+ global_batch_size = args.per_device_train_batch_size * num_processes
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+ if self.num_generations not in possible_values:
+ raise ValueError(
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
+ f"batch size, the valid values for the number of generations are: {possible_values}."
+ )
+ if self.args.eval_strategy != "no":
+ global_batch_size = args.per_device_eval_batch_size * num_processes
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+ if self.num_generations not in possible_values:
+ raise ValueError(
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
+ )
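+ # Worked example (illustrative): with 2 processes and per_device_train_batch_size = 4,
+ # the global batch size is 8, so the only valid values for num_generations are 2, 4 and 8.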
+
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
+ # it's safer to set it in all cases.
+ set_seed(args.seed, device_specific=True)
+
+ if self.use_vllm:
+ self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
+ temperature=args.temperature,
+ max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
+ else:
+ self.generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ temperature=args.temperature,
+ pad_token_id=processing_class.pad_token_id,
+ )
+
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ # Add tags to the model
+ self.model.add_model_tags(self._tag_names)
+
+ if self.ref_model is not None:
+ if self.is_deepspeed_enabled:
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ if args.sync_ref_model:
+ self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
+
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+ def _set_signature_columns_if_needed(self):
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+ if self._signature_columns is None:
+ self._signature_columns = ["prompt"]
+
+ def _get_train_sampler(self) -> Sampler:
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
+ # preventing discrepancies in group formation.
+ return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
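+ # Illustrative: with num_generations = 4, the sampler repeats each sampled dataset
+ # index 4 times in a row (e.g. [7, 7, 7, 7, 2, 2, 2, 2, ...]), so all G completions
+ # of a prompt form one contiguous group when the batch is sharded across devices.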
+
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
+ # preventing discrepancies in group formation.
+ return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
+
+ # Get the per-token log probabilities for the completions for the model and the reference model
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+ if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+ return None # Unsloth efficient GRPO
+ # Otherwise, calculate normally:
+ if not hasattr(self, '_autocast_dtype'):
+ self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+ with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+ # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+ input_ids = input_ids[:, -logits_to_keep:]
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
+ # See https://github.com/huggingface/trl/issues/2770
+ logits = logits[:, -logits_to_keep:]
+ return logits
+ # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
+ pass
+
+ def _move_model_to_vllm(self, *args, **kwargs): return None
+
+ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+ device = self.accelerator.device
+ prompts = [x["prompt"] for x in inputs]
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+ prompt_inputs = self.processing_class(
+ prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
+ )
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+ if self.max_prompt_length is not None:
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+ # Generate completions using either vLLM or regular generation
+ if self.args.use_vllm:
+ # First, have main process load weights if needed
+ if self.state.global_step != self._last_loaded_step:
+ self._move_model_to_vllm()
+ self._last_loaded_step = self.state.global_step
+
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+ all_prompts_text = gather_object(prompts_text)
+ if self.accelerator.is_main_process:
+ outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
+ completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
+ else:
+ completion_ids = [None] * len(all_prompts_text)
+ # Broadcast the completions from the main process to all processes, ensuring each process receives its
+ # corresponding slice.
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
+ process_slice = slice(
+ self.accelerator.process_index * len(prompts),
+ (self.accelerator.process_index + 1) * len(prompts),
+ )
+ completion_ids = completion_ids[process_slice]
+
+ # Pad the completions, and concatenate them with the prompts
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+ completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+ else:
+ # Regular generation path
+ with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+ prompt_completion_ids = unwrapped_model.generate(
+ prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
+ )
+
+ # Compute prompt length and extract completion ids
+ prompt_length = prompt_ids.size(1)
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
+ completion_ids = prompt_completion_ids[:, prompt_length:]
+
+ # Mask everything after the first EOS token
+ is_eos = completion_ids == self.processing_class.eos_token_id
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
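+ # Toy example (illustrative): completion_ids [[5, 9, EOS, 7]] gives eos_idx [2] and
+ # completion_mask [[1, 1, 1, 0]]: the first EOS is kept, later tokens are masked out.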
+
+ # Concatenate prompt_mask with completion_mask for logit computation
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
+
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
+
+ with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext()) if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
+ if self.ref_model is not None:
+ ref_per_token_logps = self._get_per_token_logps(
+ self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
+ )
+ else:
+ with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(
+ self.model, prompt_completion_ids, attention_mask, logits_to_keep
+ )
+
+ # Decode the generated completions
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+ if is_conversational(inputs[0]):
+ completions = []
+ for prompt, completion in zip(prompts, completions_text):
+ bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+ completions.append([{"role": "assistant", "content": bootstrap + completion}])
+ else:
+ completions = completions_text
+
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
+ if is_conversational(inputs[0]):
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+ else:
+ texts = [p + c for p, c in zip(prompts, completions)]
+ reward_inputs = reward_processing_class(
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+ )
+ reward_inputs = super()._prepare_inputs(reward_inputs)
+ with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext()) if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
+ else:
+ # Repeat all input columns (except "prompt" and "completion") to match the number of generations
+ keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
+ reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+ # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+ # completions may be distributed across processes
+ rewards_per_func = gather(rewards_per_func)
+
+ # Apply weights to each reward function's output and sum
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
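+ # Illustrative: with reward_weights = [1.0, 0.5] and per-function rewards
+ # [r1, r2] for a completion, its scalar reward is 1.0 * r1 + 0.5 * r2.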
+
+ # Compute group-wise rewards
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+ # Normalize the rewards to compute the advantages
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
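+ # Toy example (illustrative): with num_generations = 2 and rewards [1.0, 3.0] in one
+ # group, the group mean is 2.0 and std ≈ 1.41, giving advantages ≈ [-0.71, +0.71]:
+ # each completion is scored relative to the other completions of the same prompt.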
+
+ # Slice to keep only the local part of the data
+ process_slice = slice(
+ self.accelerator.process_index * len(prompts),
+ (self.accelerator.process_index + 1) * len(prompts),
+ )
+ advantages = advantages[process_slice]
+
+ # Log the metrics
+ reward_per_func = rewards_per_func.mean(0)
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+ else:
+ reward_func_name = reward_func.__name__
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+ self._metrics["reward"].append(rewards.mean().item())
+ self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
+
+ if (
+ self.log_completions
+ and self.state.global_step % self.args.logging_steps == 0
+ and "wandb" in self.args.report_to
+ ):
+ import pandas as pd
+
+ # For logging
+ table = {
+ "step": [str(self.state.global_step)] * len(rewards),
+ "prompt": gather_object(prompts_text),
+ "completion": gather_object(completions_text),
+ "reward": rewards.tolist(),
+ }
+ df = pd.DataFrame(table)
+
+ if wandb.run is not None and self.accelerator.is_main_process:
+ wandb.log({"completions": wandb.Table(dataframe=df)})
+
+ return {
+ "prompt_ids": prompt_ids,
+ "prompt_mask": prompt_mask,
+ "completion_ids": completion_ids,
+ "completion_mask": completion_mask,
+ "ref_per_token_logps": ref_per_token_logps,
+ "advantages": advantages,
+ }
+
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
+ if return_outputs:
+ raise ValueError("The GRPOTrainer does not support returning outputs")
+ # Compute the per-token log probabilities for the model
+
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+ bsz, qlen = input_ids.shape
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+ # attention_mask = None
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
+ _input_ids = input_ids
+ _logits_to_keep = logits_to_keep
+
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+
+ # Compute the KL divergence between the model and the reference model
+ ref_per_token_logps = inputs["ref_per_token_logps"]
+ # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+
+ # x - x.detach() allows for preserving gradients from x
+ advantages = inputs["advantages"]
+ # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+ # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+ # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ input_ids = input_ids[:, -logits_to_keep:]
+ if per_token_logps is not None:
+ loss, completion_length, mean_kl = grpo_compute_loss_slow(
+ ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
+ )
+ else:
+ loss, completion_length, mean_kl = grpo_accumulated_loss(
+ self, _input_ids, logits_to_keep, completion_mask, advantages,
+ n_chunks = self.args.unsloth_num_chunks,
+ )
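+ # Illustrative summary: both branches compute the GRPO objective sketched in the
+ # commented lines above, i.e. the per-token loss
+ # -(exp(logp - logp.detach()) * advantage - beta * per_token_kl),
+ # averaged over unmasked completion tokens; grpo_accumulated_loss is Unsloth's
+ # chunked variant that trades a little speed for lower peak memory.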
+
+ # Log the metrics
+ # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+
+ # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+ if "train" in self._metrics:
+ mode = "eval" if self.control.should_evaluate else "train"
+ self._metrics[mode]["completion_length"].append(completion_length.item())
+ self._metrics[mode]["kl"].append(mean_kl.item())
+ else:
+ self._metrics["completion_length"].append(completion_length.item())
+ self._metrics["kl"].append(mean_kl.item())
+ return loss
+
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
+ inputs = self._prepare_inputs(inputs)
+ with torch.no_grad():
+ with self.compute_loss_context_manager():
+ loss = self.compute_loss(model, inputs)
+ loss = loss.mean().detach()
+ return loss, None, None
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
+
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+ if next(iter(logs.keys())).startswith("eval_"):
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+ logs = {**logs, **metrics}
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ super().log(logs, start_time)
+ else: # transformers<=4.46
+ super().log(logs)
+ self._metrics.clear()
+
+ def create_model_card(
+ self,
+ model_name: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ tags: Union[str, list[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ model_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the model.
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the dataset used for training.
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+ Tags to be associated with the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+ base_model = self.model.config._name_or_path
+ else:
+ base_model = None
+
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
+ citation = textwrap.dedent(
+ """\
+ @article{zhihong2024deepseekmath,
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
+ year = 2024,
+ eprint = {arXiv:2402.03300},
+ }
+ """
+ )
+
+ model_card = generate_model_card(
+ base_model=base_model,
+ model_name=model_name,
+ hub_model_id=self.hub_model_id,
+ dataset_name=dataset_name,
+ tags=tags,
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
+ trainer_name="GRPO",
+ trainer_citation=citation,
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
+ paper_id="2402.03300",
+ )
+
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
+ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
+ """
+
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
+
+ Example:
+
+ ```python
+ from datasets import load_dataset
+ from trl import GRPOTrainer
+
+ dataset = load_dataset("trl-lib/tldr", split="train")
+
+ def reward_func(completions, **kwargs):
+ # Dummy reward function that rewards completions with more unique letters.
+ return [float(len(set(completion))) for completion in completions]
+
+ trainer = GRPOTrainer(
+ model="Qwen/Qwen2-0.5B-Instruct",
+ reward_funcs=reward_func,
+ train_dataset=dataset,
+ )
+
+ trainer.train()
+ ```
+
+ Args:
+ model (`Union[str, PreTrainedModel]`):
+ Model to be trained. Can be either:
+
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
+ a path to a *directory* containing model weights saved using
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
+ in `args.model_init_kwargs`.
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
+ functions with the prompts and completions and sum the rewards. Can be either:
+
+ - A single reward function, such as:
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
+ path to a *directory* containing model weights saved using
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
+ keyword arguments in `args.model_init_kwargs`.
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
+ - A custom reward function: The function is provided with the prompts and the generated completions,
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
+ [Using a custom reward function](#using-a-custom-reward-function).
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
+ Configuration for this trainer. If `None`, a default configuration is used.
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
+ ignored. The format of the samples can be either:
+
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+ and content).
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
+
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
+ the corresponding entries in `reward_processing_classes` are ignored.
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+ method.
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
+
+ """
+ def __init__(
+ self,
+ model,
+ reward_funcs,
+ args = None,
+ train_dataset = None,
+ eval_dataset = None,
+ processing_class = None,
+ reward_processing_classes = None,
+ callbacks = None,
+ peft_config = None,
+ **kwargs
+ ):
+ if args is None: args = UnslothGRPOConfig()
+ use_bf16 = getattr(args, 'bf16', False)
+ use_fp16 = getattr(args, 'fp16', False)
+ force_float32 = False
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
+ force_float32 = True
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+ dtype = getattr(model.config, 'torch_dtype', None)
+ if dtype is None: dtype = model.get_input_embeddings().dtype
+ from unsloth_zoo.utils import _get_dtype
+ dtype = _get_dtype(dtype)
+ float16 = dtype == torch.float16
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+ if force_float32:
+ args.fp16 = False
+ args.bf16 = False
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+ args.fp16 = float16
+ args.bf16 = not float16
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+ args.eval_strategy = 'steps'
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+ if ga_steps is not None and ga_steps > 1:
+ from transformers import __version__ as transformers_version
+ if Version(transformers_version) <= Version('4.45.2'):
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+ if getattr(args, 'eval_strategy', 'no') != 'no':
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+ if force_float32:
+ args.bf16_full_eval = False
+ args.fp16_full_eval = False
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+ args.bf16_full_eval = True
+ args.fp16_full_eval = False
+ elif not bf16_full_eval and not fp16_full_eval:
+ args.bf16_full_eval = args.bf16
+ args.fp16_full_eval = args.fp16
+ _output_logits = False
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+ if _output_logits:
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+ pass
+ else:
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
+ if args_max_seq_length is None and model_max_seq_length is not None:
+ max_seq_length = model.max_seq_length
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+ if model is not None and hasattr(model, 'for_training'):
+ model.for_training()
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+ if 'processing_class' in locals():
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+ other_metrics = []
+ if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
+ else: _reward_funcs = reward_funcs
+ for reward_func in _reward_funcs:
+ try:
+ reward_func_name = reward_func.__name__
+ other_metrics.append(f'rewards/{reward_func_name}')
+ except: pass
+
+ from unsloth_zoo.logging_utils import PatchRLStatistics
+ PatchRLStatistics('grpo_trainer', other_metrics)
+
+ super().__init__(
+ model = model,
+ reward_funcs = reward_funcs,
+ args = args,
+ train_dataset = train_dataset,
+ eval_dataset = eval_dataset,
+ processing_class = processing_class,
+ reward_processing_classes = reward_processing_classes,
+ callbacks = callbacks,
+ peft_config = peft_config,**kwargs)
+ if hasattr(self, 'neftune_hook_handle'):
+ self.neftune_hook_handle.remove()
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+ pass
+
+ pass
unsloth_compiled_cache/UnslothKTOTrainer.py ADDED
@@ -0,0 +1,1834 @@
+ """
+ 2025.5.8
+ 2025.5.7
+ 4.51.3
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings)
+
+
+ import os
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+ torch_compile_options = {
+ "epilogue_fusion" : True,
+ "max_autotune" : False,
+ "shape_padding" : True,
+ "trace.enabled" : False,
+ "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+ logits = logits.to(torch.float32)
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+ # loop to reduce peak mem consumption
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
+ return per_token_logps
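The gather-then-logsumexp trick in `selective_log_softmax` uses the identity log_softmax(x)_i = x_i - logsumexp(x), so it matches a full `F.log_softmax` followed by a gather while never materializing the normalized (batch, seq, vocab) tensor. A minimal sanity-check sketch with made-up shapes (not part of the generated file):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)          # (batch, seq_len, vocab) -- invented sizes
index = torch.randint(0, 11, (2, 5))    # token ids whose log-probs we want

# Gather-then-logsumexp, as in selective_log_softmax above
selected = torch.gather(logits.float(), -1, index.unsqueeze(-1)).squeeze(-1)
fast = selected - torch.logsumexp(logits.float(), dim = -1)

# Reference path: normalize the whole vocab axis, then gather
slow = torch.gather(F.log_softmax(logits.float(), dim = -1), -1, index.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(fast, slow, atol = 1e-5)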
+ @dataclass
+ class UnslothKTOConfig(KTOConfig):
+ """
+
+ Configuration class for the [`KTOTrainer`].
+
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
+ Initial learning rate for the [`AdamW`] optimizer. The default value replaces that of
+ [`~transformers.TrainingArguments`].
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
+ to use the default data collator.
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
+ Maximum length of the completion. This argument is required if you want to use the default data collator
+ and your model is an encoder-decoder.
+ beta (`float`, *optional*, defaults to `0.1`):
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
+ reference model.
+ loss_type (`str`, *optional*, defaults to `"kto"`):
+ Type of loss to use. Possible values are:
+
+ - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
+ - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
+
+ desirable_weight (`float`, *optional*, defaults to `1.0`):
+ Desirable losses are weighed by this factor to counter the unequal number of desirable and undesirable pairs.
+ undesirable_weight (`float`, *optional*, defaults to `1.0`):
+ Undesirable losses are weighed by this factor to counter the unequal number of desirable and undesirable pairs.
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
+ Label pad token id. This argument is required if you want to use the default data collator.
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
+ This argument is required if you want to use the default data collator.
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
+ evaluation.
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
+ you need to specify if the model returned by the callable is an encoder-decoder model.
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
+ useful when training without the reference model to reduce the total GPU memory needed.
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
+ string.
+ ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
+ from a string.
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
+ Number of processes to use for processing the dataset.
+ disable_dropout (`bool`, *optional*, defaults to `True`):
+ Whether to disable dropout in the model and reference model.
+
+ """
+ vllm_sampling_params: Optional[Any] = field(
+ default = None,
+ metadata = {'help': 'vLLM SamplingParams'},
+ )
+ unsloth_num_chunks : Optional[int] = field(
+ default = -1,
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+ )
+ def __init__(
+ self,
+ output_dir = None,
+ overwrite_output_dir = None,
+ do_train = False,
+ do_eval = False,
+ do_predict = False,
+ eval_strategy = 'no',
+ prediction_loss_only = False,
+ per_device_train_batch_size = 4,
+ per_device_eval_batch_size = 4,
+ per_gpu_train_batch_size = None,
+ per_gpu_eval_batch_size = None,
+ gradient_accumulation_steps = 2,
+ eval_accumulation_steps = 2,
+ eval_delay = 0,
+ torch_empty_cache_steps = 250,
+ learning_rate = 5e-05,
+ weight_decay = 0.01,
+ adam_beta1 = 0.9,
+ adam_beta2 = 0.999,
+ adam_epsilon = 1e-08,
+ max_grad_norm = 1.0,
+ num_train_epochs = 3.0,
+ max_steps = -1,
+ lr_scheduler_type = 'linear',
+ warmup_ratio = 0.1,
+ warmup_steps = 0,
+ log_level = 'passive',
+ log_level_replica = 'warning',
+ log_on_each_node = True,
+ logging_dir = None,
+ logging_strategy = 'steps',
+ logging_first_step = False,
+ logging_steps = 1,
+ logging_nan_inf_filter = False,
+ save_strategy = 'steps',
+ save_steps = 500,
+ save_total_limit = None,
+ save_safetensors = True,
+ save_on_each_node = False,
+ save_only_model = False,
+ restore_callback_states_from_checkpoint = False,
+ no_cuda = False,
+ use_cpu = False,
+ use_mps_device = False,
+ seed = 3407,
+ data_seed = 3407,
+ jit_mode_eval = False,
+ use_ipex = False,
+ bf16 = False,
+ fp16 = False,
+ fp16_opt_level = 'O1',
+ half_precision_backend = 'auto',
+ bf16_full_eval = False,
+ fp16_full_eval = False,
+ tf32 = None,
+ local_rank = -1,
+ ddp_backend = None,
+ tpu_num_cores = None,
+ tpu_metrics_debug = False,
+ debug = '',
+ dataloader_drop_last = False,
+ eval_steps = None,
+ dataloader_num_workers = 0,
+ dataloader_prefetch_factor = None,
+ past_index = -1,
+ run_name = None,
+ disable_tqdm = None,
+ remove_unused_columns = True,
+ label_names = None,
+ load_best_model_at_end = False,
+ metric_for_best_model = None,
+ greater_is_better = None,
+ ignore_data_skip = False,
+ fsdp = '',
+ fsdp_min_num_params = 0,
+ fsdp_config = None,
+ tp_size = 0,
+ fsdp_transformer_layer_cls_to_wrap = None,
+ accelerator_config = None,
+ deepspeed = None,
+ label_smoothing_factor = 0.0,
+ optim = 'adamw_8bit',
+ optim_args = None,
+ adafactor = False,
+ group_by_length = False,
+ length_column_name = 'length',
+ report_to = None,
+ ddp_find_unused_parameters = None,
+ ddp_bucket_cap_mb = None,
+ ddp_broadcast_buffers = None,
+ dataloader_pin_memory = True,
+ dataloader_persistent_workers = False,
+ skip_memory_metrics = True,
+ use_legacy_prediction_loop = False,
+ push_to_hub = False,
+ resume_from_checkpoint = None,
+ hub_model_id = None,
+ hub_strategy = 'every_save',
+ hub_token = None,
+ hub_private_repo = None,
+ hub_always_push = False,
+ gradient_checkpointing = False,
+ gradient_checkpointing_kwargs = None,
+ include_inputs_for_metrics = False,
+ eval_do_concat_batches = True,
+ fp16_backend = 'auto',
+ push_to_hub_model_id = None,
+ push_to_hub_organization = None,
+ push_to_hub_token = None,
+ mp_parameters = '',
+ auto_find_batch_size = False,
+ full_determinism = False,
+ torchdynamo = None,
+ ray_scope = 'last',
+ ddp_timeout = 1800,
+ torch_compile = False,
+ torch_compile_backend = None,
+ torch_compile_mode = None,
+ include_tokens_per_second = False,
+ include_num_input_tokens_seen = False,
+ neftune_noise_alpha = None,
+ optim_target_modules = None,
+ batch_eval_metrics = False,
+ eval_on_start = False,
+ use_liger_kernel = False,
+ eval_use_gather_object = False,
+ average_tokens_across_devices = False,
+ max_length = 1024,
+ max_prompt_length = 512,
+ max_completion_length = None,
+ beta = 0.1,
+ loss_type = 'kto',
+ desirable_weight = 1.0,
+ undesirable_weight = 1.0,
+ label_pad_token_id = -100,
+ padding_value = None,
+ truncation_mode = 'keep_end',
+ generate_during_eval = False,
+ is_encoder_decoder = None,
+ disable_dropout = True,
+ precompute_ref_log_probs = False,
+ model_init_kwargs = None,
+ ref_model_init_kwargs = None,
+ dataset_num_proc = None,
+ vllm_sampling_params = None,
+ unsloth_num_chunks = -1,
+ **kwargs,
+ ):
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it; otherwise gradient updates will be close to 0!')
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is too large (greater than 1)! Consider decreasing it to 1e-1; otherwise gradient updates will explode!')
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+ output_dir = 'unsloth_training_checkpoints'
+ save_strategy = 'no'
+ if dataset_num_proc is None:
+ from multiprocessing import cpu_count
+ dataset_num_proc = cpu_count()
+
+ super().__init__(
+ output_dir = output_dir,
+ overwrite_output_dir = overwrite_output_dir,
+ do_train = do_train,
+ do_eval = do_eval,
+ do_predict = do_predict,
+ eval_strategy = eval_strategy,
+ prediction_loss_only = prediction_loss_only,
+ per_device_train_batch_size = per_device_train_batch_size,
+ per_device_eval_batch_size = per_device_eval_batch_size,
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+ gradient_accumulation_steps = gradient_accumulation_steps,
+ eval_accumulation_steps = eval_accumulation_steps,
+ eval_delay = eval_delay,
+ torch_empty_cache_steps = torch_empty_cache_steps,
+ learning_rate = learning_rate,
+ weight_decay = weight_decay,
+ adam_beta1 = adam_beta1,
+ adam_beta2 = adam_beta2,
+ adam_epsilon = adam_epsilon,
+ max_grad_norm = max_grad_norm,
+ num_train_epochs = num_train_epochs,
+ max_steps = max_steps,
+ lr_scheduler_type = lr_scheduler_type,
+ warmup_ratio = warmup_ratio,
+ warmup_steps = warmup_steps,
+ log_level = log_level,
+ log_level_replica = log_level_replica,
+ log_on_each_node = log_on_each_node,
+ logging_dir = logging_dir,
+ logging_strategy = logging_strategy,
+ logging_first_step = logging_first_step,
+ logging_steps = logging_steps,
+ logging_nan_inf_filter = logging_nan_inf_filter,
+ save_strategy = save_strategy,
+ save_steps = save_steps,
+ save_total_limit = save_total_limit,
+ save_safetensors = save_safetensors,
+ save_on_each_node = save_on_each_node,
+ save_only_model = save_only_model,
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+ no_cuda = no_cuda,
+ use_cpu = use_cpu,
+ use_mps_device = use_mps_device,
+ seed = seed,
+ data_seed = data_seed,
+ jit_mode_eval = jit_mode_eval,
+ use_ipex = use_ipex,
+ bf16 = bf16,
+ fp16 = fp16,
+ fp16_opt_level = fp16_opt_level,
+ half_precision_backend = half_precision_backend,
+ bf16_full_eval = bf16_full_eval,
+ fp16_full_eval = fp16_full_eval,
+ tf32 = tf32,
+ local_rank = local_rank,
+ ddp_backend = ddp_backend,
+ tpu_num_cores = tpu_num_cores,
+ tpu_metrics_debug = tpu_metrics_debug,
+ debug = debug,
+ dataloader_drop_last = dataloader_drop_last,
+ eval_steps = eval_steps,
+ dataloader_num_workers = dataloader_num_workers,
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
+ past_index = past_index,
+ run_name = run_name,
+ disable_tqdm = disable_tqdm,
+ remove_unused_columns = remove_unused_columns,
+ label_names = label_names,
+ load_best_model_at_end = load_best_model_at_end,
+ metric_for_best_model = metric_for_best_model,
+ greater_is_better = greater_is_better,
+ ignore_data_skip = ignore_data_skip,
+ fsdp = fsdp,
+ fsdp_min_num_params = fsdp_min_num_params,
+ fsdp_config = fsdp_config,
+ tp_size = tp_size,
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+ accelerator_config = accelerator_config,
+ deepspeed = deepspeed,
+ label_smoothing_factor = label_smoothing_factor,
+ optim = optim,
+ optim_args = optim_args,
+ adafactor = adafactor,
+ group_by_length = group_by_length,
+ length_column_name = length_column_name,
+ report_to = report_to,
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
+ dataloader_pin_memory = dataloader_pin_memory,
+ dataloader_persistent_workers = dataloader_persistent_workers,
+ skip_memory_metrics = skip_memory_metrics,
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
+ push_to_hub = push_to_hub,
+ resume_from_checkpoint = resume_from_checkpoint,
+ hub_model_id = hub_model_id,
+ hub_strategy = hub_strategy,
+ hub_token = hub_token,
+ hub_private_repo = hub_private_repo,
+ hub_always_push = hub_always_push,
+ gradient_checkpointing = gradient_checkpointing,
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+ include_inputs_for_metrics = include_inputs_for_metrics,
+ eval_do_concat_batches = eval_do_concat_batches,
+ fp16_backend = fp16_backend,
+ push_to_hub_model_id = push_to_hub_model_id,
+ push_to_hub_organization = push_to_hub_organization,
+ push_to_hub_token = push_to_hub_token,
+ mp_parameters = mp_parameters,
+ auto_find_batch_size = auto_find_batch_size,
+ full_determinism = full_determinism,
+ torchdynamo = torchdynamo,
+ ray_scope = ray_scope,
+ ddp_timeout = ddp_timeout,
+ torch_compile = torch_compile,
+ torch_compile_backend = torch_compile_backend,
+ torch_compile_mode = torch_compile_mode,
+ include_tokens_per_second = include_tokens_per_second,
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
+ neftune_noise_alpha = neftune_noise_alpha,
+ optim_target_modules = optim_target_modules,
+ batch_eval_metrics = batch_eval_metrics,
+ eval_on_start = eval_on_start,
+ use_liger_kernel = use_liger_kernel,
+ eval_use_gather_object = eval_use_gather_object,
+ average_tokens_across_devices = average_tokens_across_devices,
+ max_length = max_length,
+ max_prompt_length = max_prompt_length,
+ max_completion_length = max_completion_length,
+ beta = beta,
+ loss_type = loss_type,
+ desirable_weight = desirable_weight,
+ undesirable_weight = undesirable_weight,
+ label_pad_token_id = label_pad_token_id,
+ padding_value = padding_value,
+ truncation_mode = truncation_mode,
+ generate_during_eval = generate_during_eval,
+ is_encoder_decoder = is_encoder_decoder,
+ disable_dropout = disable_dropout,
+ precompute_ref_log_probs = precompute_ref_log_probs,
+ model_init_kwargs = model_init_kwargs,
+ ref_model_init_kwargs = ref_model_init_kwargs,
+ dataset_num_proc = dataset_num_proc, **kwargs)
+ self.vllm_sampling_params = vllm_sampling_params
+ self.unsloth_num_chunks = unsloth_num_chunks
+ pass
+
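The config above is a drop-in replacement for `KTOConfig` with Unsloth-flavored defaults (8-bit AdamW, seed 3407, `dataset_num_proc` auto-set to the CPU count). A minimal usage sketch, assuming the `unsloth_compiled_cache` directory is importable; all values here are hypothetical:

from unsloth_compiled_cache.UnslothKTOTrainer import UnslothKTOConfig

config = UnslothKTOConfig(
    output_dir = "kto_out",              # placeholder path
    per_device_train_batch_size = 4,     # must be > 1 so the KL estimate works
    learning_rate = 5e-7,                # KTO's documented default
    beta = 0.1,                          # higher beta = stay closer to the reference model
    desirable_weight = 1.0,
    undesirable_weight = 1.0,
)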
+ class _UnslothKTOTrainer(Trainer):
+ r""""""
+
+ _tag_names = ["trl", "kto"]
+
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module, str] = None,
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ args: KTOConfig = None,
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+ processing_class: Optional[
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+ ] = None,
+ data_collator: Optional[DataCollator] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ peft_config: Optional[dict] = None,
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
+ model_adapter_name: Optional[str] = None,
+ ref_adapter_name: Optional[str] = None,
+ ):
+ if type(args) is TrainingArguments:
+ raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")
+
+ if not isinstance(model, str) and ref_model is model:
+ raise ValueError(
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+ "same as `model`, you must pass a copy of it, or `None` if you use peft."
+ )
+
+ if args.model_init_kwargs is None:
+ model_init_kwargs = {}
+ elif not isinstance(model, str):
+ raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
+ else:
+ model_init_kwargs = args.model_init_kwargs
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if torch_dtype is not None:
+ # Convert to `torch.dtype` if a str is passed
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
+ torch_dtype = getattr(torch, torch_dtype)
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
+ )
+ model_init_kwargs["torch_dtype"] = torch_dtype
+
+ if args.ref_model_init_kwargs is None:
+ ref_model_init_kwargs = {}
+ elif not isinstance(ref_model, str):
+ raise ValueError(
+ "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
+ )
+ else:
+ ref_model_init_kwargs = args.ref_model_init_kwargs
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
+ if torch_dtype is not None:
+ # Convert to `torch.dtype` if a str is passed
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
+ torch_dtype = getattr(torch, torch_dtype)
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
+ )
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
+
+ if isinstance(model, str):
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+ if isinstance(ref_model, str):
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
+
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+ # has been called in order to properly call autocast if needed.
+ self._peft_has_been_casted_to_bf16 = False
+
+ if not is_peft_available() and peft_config is not None:
+ raise ValueError(
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
+ )
+ elif is_peft_available() and peft_config is not None:
+ # if model is a peft model and we have a peft_config, we merge and unload it first
+ if isinstance(model, PeftModel):
+ model = model.merge_and_unload()
+
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+ _support_gc_kwargs = hasattr(
+ args, "gradient_checkpointing_kwargs"
+ ) and "gradient_checkpointing_kwargs" in list(
+ inspect.signature(prepare_model_for_kbit_training).parameters
+ )
+
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+ if _support_gc_kwargs:
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+ elif getattr(args, "gradient_checkpointing", False):
+ # For backward compatibility with older versions of transformers
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ # get peft model with the given config
+ model = model
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+ peft_module_casting_to_bf16(model)
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+ self._peft_has_been_casted_to_bf16 = True
+
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
+ # fail or completely fail.
+ elif getattr(args, "gradient_checkpointing", False):
+ # For backward compatibility with older versions of transformers
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
+ raise ValueError(
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+ " Please install `wandb` or `comet-ml` to resolve."
+ )
+
+ if model is not None:
+ self.is_encoder_decoder = model.config.is_encoder_decoder
+ elif args.is_encoder_decoder is None:
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+ else:
+ self.is_encoder_decoder = args.is_encoder_decoder
+
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
+ self.model_adapter_name = model_adapter_name
+ self.ref_adapter_name = ref_adapter_name
+
+ if ref_model:
+ self.ref_model = ref_model
+ elif self.is_peft_model or args.precompute_ref_log_probs:
+ # The `model` with adapters turned off will be used as the reference model
+ self.ref_model = None
+ else:
+ self.ref_model = create_reference_model(model)
+
+ if processing_class is None:
+ raise ValueError(
+ "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
+ )
+ if args.max_length is None:
+ warnings.warn(
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init;"
+ " it will be set to `512` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_length = 512
+ if args.max_length is not None:
+ max_length = args.max_length
+
+ if args.max_prompt_length is None:
+ warnings.warn(
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init;"
+ " it will be set to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_prompt_length = 128
+ if args.max_prompt_length is not None:
+ max_prompt_length = args.max_prompt_length
+
+ max_completion_length = None
+ if args.max_completion_length is None and self.is_encoder_decoder:
+ warnings.warn(
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init;"
+ " it will be set to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_completion_length = 128
+ if args.max_completion_length is not None and self.is_encoder_decoder:
+ max_completion_length = args.max_completion_length
+
+ if data_collator is None:
+ data_collator = DPODataCollatorWithPadding(
+ pad_token_id=processing_class.pad_token_id,
+ label_pad_token_id=args.label_pad_token_id,
+ is_encoder_decoder=self.is_encoder_decoder,
+ )
+
+ if args.remove_unused_columns:
+ args.remove_unused_columns = False
+ # warn users
+ warnings.warn(
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig;"
+ " we have set it for you, but you should do it yourself in the future.",
+ UserWarning,
+ )
+
+ self.use_dpo_data_collator = True
+ else:
+ self.use_dpo_data_collator = False
+
+ # Disable dropout in the model and reference model
+ if args.disable_dropout:
+ disable_dropout_in_model(model)
+ if self.ref_model is not None:
+ disable_dropout_in_model(self.ref_model)
+
+ self.loss_type = args.loss_type
+ self.max_length = max_length
+ self.generate_during_eval = args.generate_during_eval
+ self.label_pad_token_id = args.label_pad_token_id
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
+ self.max_prompt_length = max_prompt_length
+ self.truncation_mode = args.truncation_mode
+ self.max_completion_length = max_completion_length
+ self.processing_class = processing_class
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
+
+ # Not all losses require a KL calculation
+ self.calculate_KL = True
+ if self.loss_type in ["apo_zero_unpaired"]:
+ self.calculate_KL = False
+
+ # Since ref logps are precomputed on the first call to get_train/eval_dataloader,
+ # keep track of the first call to avoid recomputing them on future calls
+ self._precomputed_train_ref_log_probs = False
+ self._precomputed_eval_ref_log_probs = False
+
+ # metrics
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ # KTO parameters
+ self.beta = args.beta
+ self.desirable_weight = args.desirable_weight
+ self.undesirable_weight = args.undesirable_weight
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
+ warnings.warn(
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+ "loss.",
+ UserWarning,
+ )
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
+ # issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ # Compute that only on the main process for faster data processing.
+ # see: https://github.com/huggingface/trl/pull/1255
+ with PartialState().local_main_process_first():
+ # Extract the prompt if needed
+ train_dataset = train_dataset.map(
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
+ )
+ # Unpair the dataset if needed
+ train_dataset = maybe_unpair_preference_dataset(
+ train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
+ )
+ # Apply the chat template if needed
+ train_dataset = train_dataset.map(
+ maybe_apply_chat_template,
+ fn_kwargs={"tokenizer": processing_class},
+ num_proc=args.dataset_num_proc,
+ desc="Applying chat template to train dataset",
+ )
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.map(
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
+ )
+ eval_dataset = maybe_unpair_preference_dataset(
+ eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
+ )
+ eval_dataset = eval_dataset.map(
+ maybe_apply_chat_template,
+ fn_kwargs={"tokenizer": processing_class},
+ num_proc=args.dataset_num_proc,
+ desc="Applying chat template to eval dataset",
+ )
+
+ # Tokenize and prepare the training datasets
+ train_dataset = train_dataset.map(
+ _tokenize,
+ batched=True,
+ fn_kwargs={"tokenizer": self.processing_class},
+ num_proc=args.dataset_num_proc,
+ desc="Tokenizing train dataset",
+ )
+
+ fn_kwargs = {
+ "prefix": "",
+ "is_encoder_decoder": self.is_encoder_decoder,
+ "tokenizer": self.processing_class,
+ "max_length": self.max_length,
+ "truncation_mode": self.truncation_mode,
+ "label_pad_token_id": self.label_pad_token_id,
+ "max_prompt_length": self.max_prompt_length,
+ "max_completion_length": self.max_completion_length,
+ }
+
+ train_dataset = train_dataset.map(
+ _process_tokens,
+ fn_kwargs=fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ desc="Processing tokenized train dataset",
+ )
+
+ # Tokenize and prepare the eval datasets
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.map(
+ _tokenize,
+ fn_kwargs={"tokenizer": self.processing_class},
+ batched=True,
+ num_proc=args.dataset_num_proc,
+ desc="Tokenizing eval dataset",
+ )
+
+ eval_dataset = eval_dataset.map(
+ _process_tokens,
+ fn_kwargs=fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ desc="Processing tokenized eval dataset",
+ )
+
+ # Get KL datasets if needed
+ if self.calculate_KL:
+ if args.per_device_train_batch_size <= 1:
+ raise ValueError(
+ "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
+ )
+
+ # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
+ # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
+ train_kl_dataset = train_dataset.map(
+ _get_kl_dataset,
+ batched=True,
+ batch_size=args.per_device_train_batch_size,
+ num_proc=args.dataset_num_proc,
+ desc="Extracting KL train dataset",
+ )
+
+ fn_kwargs["prefix"] = "KL_"
+ train_kl_dataset = train_kl_dataset.map(
+ _process_tokens,
+ fn_kwargs=fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
+ desc="Processing tokenized train KL dataset",
+ )
+
+ # merge the datasets
+ train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
+
+ if eval_dataset is not None:
+ # Get KL dataset
+ eval_kl_dataset = eval_dataset.map(
+ _get_kl_dataset,
+ batched=True,
+ batch_size=args.per_device_train_batch_size,
+ num_proc=args.dataset_num_proc,
+ desc="Extracting eval KL dataset",
+ )
+
+ eval_kl_dataset = eval_kl_dataset.map(
+ _process_tokens,
+ fn_kwargs=fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
+ desc="Processing tokenized eval KL dataset",
+ )
+
+ # merge the datasets
+ eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
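The pair-flipping comment above means each prompt in a batch gets matched with a completion written for a different prompt, which is exactly the mismatched distribution the KL term needs. A toy sketch of one such mismatching scheme, shown here as a rotation (stand-in data; this illustrates the idea, not the exact `_get_kl_dataset` implementation):

prompts     = ["x1", "x2", "x3", "x4"]
completions = ["y1", "y2", "y3", "y4"]

# Rotate completions by one position within the batch
kl_completions = completions[-1:] + completions[:-1]   # ["y4", "y1", "y2", "y3"]
kl_pairs = list(zip(prompts, kl_completions))
# [("x1", "y4"), ("x2", "y1"), ("x3", "y2"), ("x4", "y3")]
# Every prompt now carries someone else's completion, so scoring these pairs
# estimates log-prob ratios under mismatched (prompt, completion) data.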
+
+ # calculate dataset desirability balance
+ num_desirable = max(sum(train_dataset["label"]), 1)
+ num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
+
+ if num_desirable != num_undesirable:
+ # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
+ des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
+ des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
+ und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
+ und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
+
+ des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
+ und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
+
+ if not (des_weight_in_range or und_weight_in_range):
+ warnings.warn(
+ "You have different amounts of desirable/positive and undesirable/negative examples but the "
+ "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
+ f"on your data, we recommend EITHER "
+ f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
+ f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
+ "See the documentation on how to optimally set these weights.",
+ UserWarning,
+ )
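To make the Eq. (8) bounds above concrete, consider a hypothetical dataset with 300 desirable and 100 undesirable examples at the default weights of 1.0:

num_desirable, num_undesirable = 300, 100   # made-up counts
desirable_weight = undesirable_weight = 1.0

# Either desirable_weight should fall in [100/300 * 1, 100/300 * 1.33],
# or undesirable_weight in [300/100 / 1.33, 300/100 / 1] -- but not both.
des_bounds = (round(100 / 300 * 1, 2), round(100 / 300 * 1.33, 2))   # (0.33, 0.44)
und_bounds = (round(300 / 100 / 1.33, 2), round(300 / 100 / 1, 2))   # (2.26, 3.0)
# With both weights left at 1.0, neither is in range, so the warning fires.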
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ model_init=model_init,
+ compute_metrics=compute_metrics,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ )
+
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ # Add tags for models that have been loaded with the correct transformers version
+ if hasattr(self.model, "add_model_tags"):
+ self.model.add_model_tags(self._tag_names)
+
+ if not hasattr(self, "accelerator"):
+ raise AttributeError(
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
+ )
+
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
+ if self.is_deepspeed_enabled:
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
+ raise ValueError(
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
+ )
+
+ if self.ref_model is None:
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
+ raise ValueError(
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
+ )
+ else:
+ if self.is_deepspeed_enabled:
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+
+ if model is not None:
+ if hasattr(model, "config"):
+ hidden_size = (
+ max(model.config.hidden_sizes)
+ if getattr(model.config, "hidden_sizes", None)
+ else getattr(model.config, "hidden_size", None)
+ )
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+ config_kwargs.update(
+ {
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+ }
+ )
+
+ # If ZeRO-3 is used, we shard both the active and reference model.
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
+ if config_kwargs["zero_optimization"]["stage"] != 3:
+ config_kwargs["zero_optimization"]["stage"] = 0
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+ model.eval()
+ return model
+
+ @contextmanager
+ def null_ref_context(self):
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+ with (
+ self.accelerator.unwrap_model(self.model).disable_adapter()
+ if self.is_peft_model and not self.ref_adapter_name
+ else nullcontext()
+ ):
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.ref_adapter_name)
+ yield
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.model_adapter_name or "default")
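`null_ref_context` lets one PEFT model play both roles: with adapters disabled (or swapped to `ref_adapter_name`) inside the block, forward passes behave like the reference model. A hedged fragment of how the trainer uses it; `batch` is a placeholder and this is not standalone code:

# Inside a trainer method, with self.ref_model = None and a PEFT self.model:
with torch.no_grad(), self.null_ref_context():
    # Adapters are off (or swapped) here, so this forward pass is the
    # reference model's view; they are restored when the block exits.
    ref_logits = self.model(
        batch["completion_input_ids"],
        attention_mask=batch["completion_attention_mask"],
    ).logits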
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Overrides transformers.Trainer.get_train_dataloader to precompute `ref_log_probs`.
+ """
+
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
+ dataloader_params = {
+ "batch_size": self.args.per_device_train_batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
+ reference_completion_logps = []
+ reference_KL_logps = []
+
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
+
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
+ reference_completion_logps.append(reference_completion_logp.cpu())
+
+ if self.calculate_KL:
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
+ reference_KL_logps.append(reference_KL_logp.cpu())
+
+ self.train_dataset = self.train_dataset.add_column(
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
+ )
+
+ if self.calculate_KL:
+ self.train_dataset = self.train_dataset.add_column(
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
+ )
+
+ self._precomputed_train_ref_log_probs = True
+
+ return super().get_train_dataloader()
+
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+ Overrides transformers.Trainer.get_eval_dataloader to precompute `ref_log_probs`.
+
+ Args:
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
+ dataloader_params = {
+ "batch_size": self.args.per_device_eval_batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
+
+ reference_completion_logps = []
+ reference_KL_logps = []
+
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
+
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
+ reference_completion_logps.append(reference_completion_logp.cpu())
+
+ if self.calculate_KL:
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
+ reference_KL_logps.append(reference_KL_logp.cpu())
+
+ eval_dataset = eval_dataset.add_column(
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
+ )
+ if self.calculate_KL:
+ eval_dataset = eval_dataset.add_column(
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
+ )
+
+ # Save calculated reference_logps and reference_KL_logps to the eval_dataset for subsequent runs
+ if self.eval_dataset is not None:
+ self.eval_dataset = eval_dataset
+ self._precomputed_eval_ref_log_probs = True
+
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
+
+ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
+ """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
+ with torch.no_grad():
+ if self.ref_model is None:
+ with self.null_ref_context():
+ if self.is_encoder_decoder:
+ completion_logits = self.model(
+ padded_batch["prompt_input_ids"],
+ attention_mask=padded_batch["prompt_attention_mask"],
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
+ labels=padded_batch["completion_labels"],
+ ).logits
+
+ if self.calculate_KL:
+ KL_logits = self.model(
+ padded_batch["KL_prompt_input_ids"],
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
+ labels=padded_batch["KL_completion_labels"],
+ ).logits
+ else:
+ completion_logits = self.model(
+ padded_batch["completion_input_ids"],
+ attention_mask=padded_batch["completion_attention_mask"],
+ ).logits
+
+ if self.calculate_KL:
+ KL_logits = self.model(
+ padded_batch["KL_completion_input_ids"],
+ attention_mask=padded_batch["KL_completion_attention_mask"],
+ ).logits
+ else:
+ if self.is_encoder_decoder:
+ completion_logits = self.ref_model(
+ padded_batch["prompt_input_ids"],
+ attention_mask=padded_batch["prompt_attention_mask"],
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
+ labels=padded_batch["completion_labels"],
+ ).logits
+
+ if self.calculate_KL:
+ KL_logits = self.ref_model(
+ padded_batch["KL_prompt_input_ids"],
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
+ labels=padded_batch["KL_completion_labels"],
+ ).logits
+ else:
+ completion_logits = self.ref_model(
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
+ ).logits
+
+ if self.calculate_KL:
+ KL_logits = self.ref_model(
+ padded_batch["KL_completion_input_ids"],
+ attention_mask=padded_batch["KL_completion_attention_mask"],
+ ).logits
+
+ completion_logps = self.get_batch_logps(
+ completion_logits,
+ padded_batch["completion_labels"],
+ average_log_prob=False,
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+
+ if self.calculate_KL:
+ KL_logps = self.get_batch_logps(
+ KL_logits,
+ padded_batch["KL_completion_labels"],
+ average_log_prob=False,
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+ else:
+ KL_logps = None
+
+ return completion_logps, KL_logps
+
+ @staticmethod
+ def get_batch_logps(
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ label_pad_token_id: int = -100,
+ is_encoder_decoder: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ if logits.shape[:-1] != labels.shape:
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
+
+ if not is_encoder_decoder:
+ labels = labels[:, 1:].clone()
+ logits = logits[:, :-1, :]
+ else:
+ # Fixes enc-dec RuntimeError
+ labels = labels.clone()
+
+ loss_mask = labels != label_pad_token_id
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == label_pad_token_id] = 0
+
+ per_token_logps = selective_log_softmax(logits, labels)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
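For decoder-only models, `get_batch_logps` shifts labels one position left of the logits before scoring, since the logits at position t predict the token at position t + 1. A toy walkthrough with invented shapes (using a plain log_softmax gather in place of the compiled helper):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 8)                 # (batch, seq_len, vocab) -- invented
labels = torch.tensor([[3, 5, -100, 7]])      # -100 marks a masked position

# Decoder-only shift: logits at position t predict the token at position t + 1
labels_s = labels[:, 1:].clone()              # [[5, -100, 7]]
logits_s = logits[:, :-1, :]

loss_mask = labels_s != -100                  # [[True, False, True]]
labels_s[labels_s == -100] = 0                # dummy ids, masked out below

per_token = torch.gather(F.log_softmax(logits_s.float(), -1), -1,
                         labels_s.unsqueeze(-1)).squeeze(-1)
seq_logp = (per_token * loss_mask).sum(-1)    # one summed log-prob per sequence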
+
+ def forward(
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ if self.calculate_KL:
+ KL_logps = None
+ KL_model_kwargs = (
+ {
+ "input_ids": batch["KL_prompt_input_ids"],
+ "attention_mask": batch["KL_prompt_attention_mask"],
+ "labels": batch["KL_completion_labels"],
+ "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
+ }
+ if self.is_encoder_decoder
+ else {
+ "input_ids": batch["KL_completion_input_ids"],
+ "attention_mask": batch["KL_completion_attention_mask"],
+ }
+ )
+ with torch.no_grad():
+ KL_logits = model(
+ **KL_model_kwargs,
+ ).logits
+
+ KL_logps = self.get_batch_logps(
+ KL_logits,
+ batch["KL_completion_labels"],
+ average_log_prob=False,
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+ else:
+ KL_logps = None
+
+ model_kwargs = (
+ {
+ "labels": batch["completion_labels"],
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
+ }
+ if self.is_encoder_decoder
+ else {}
+ )
+ if self.aux_loss_enabled:
+ model_kwargs["output_router_logits"] = True
+
+ outputs = model(
+ batch["completion_input_ids"],
+ attention_mask=batch["completion_attention_mask"],
+ **model_kwargs,
+ )
+ completion_logits = outputs.logits
+
+ completion_logps = self.get_batch_logps(
+ completion_logits,
+ batch["completion_labels"],
+ average_log_prob=False,
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+
+ if completion_logps.shape[0] != len(batch["label"]):
+ raise ValueError(
+ "There is a mismatch between the number of examples in this batch and the number of "
+ "examples for which an output sequence was predicted."
+ )
+
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
+
+ chosen_logps = completion_logps[chosen_idx, ...]
+ rejected_logps = completion_logps[rejected_idx, ...]
+
+ chosen_logits = completion_logits[chosen_idx, ...]
+ rejected_logits = completion_logits[rejected_idx, ...]
+
+ if self.aux_loss_enabled:
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
+ else:
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
+
+ def kto_loss(
+ self,
+ policy_chosen_logps: torch.FloatTensor,
+ policy_rejected_logps: torch.FloatTensor,
+ policy_KL_logps: torch.FloatTensor,
+ reference_chosen_logps: torch.FloatTensor,
+ reference_rejected_logps: torch.FloatTensor,
+ reference_KL_logps: torch.FloatTensor,
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ """Compute the KTO loss for a batch of policy and reference model log probabilities.
+
+ Args:
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
+ policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
+ reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
+
+ Returns:
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
+ The losses tensor contains the KTO loss for each example in the batch.
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
+ The KL tensor contains the detached KL divergence estimate between the policy and reference models.
+ """
+ if self.calculate_KL:
+ kl = (policy_KL_logps - reference_KL_logps).mean().detach()
+ kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
+ else:
+ kl = torch.zeros(1).to(policy_chosen_logps.device)
+
+ # Chosen losses
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
+
+ if self.loss_type == "kto":
+ # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
+ chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
+ elif self.loss_type == "apo_zero_unpaired":
+ # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+ # Use this loss when you believe the chosen outputs are better than your model's default output
+ chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
+
+ chosen_rewards = self.beta * chosen_logratios.detach()
+
+ else:
+ # lists can't be empty -- if they are, then accelerate.gather will hang
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
+
+ # Rejected losses
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
+
+ if self.loss_type == "kto":
+ rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
+ elif self.loss_type == "apo_zero_unpaired":
+ rejected_losses = F.sigmoid(self.beta * rejected_logratios)
+
+ rejected_rewards = self.beta * rejected_logratios.detach()
+ else:
+ # lists can't be empty -- if they are, then accelerate.gather will hang
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
+
+ losses = torch.cat(
+ (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
+ 0,
+ )
+
+ return losses, chosen_rewards, rejected_rewards, kl
+
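(Editorial aside, not part of the diff: the chosen branch above is Eqn (7) of the KTO paper. A minimal standalone sketch, with illustrative beta, shapes, and values:)

import torch

def kto_chosen_loss_sketch(policy_logps, ref_logps, kl, beta=0.1):
    # 1 - sigmoid(beta * ((log pi - log ref) - KL)); kl is the detached, clamped estimate
    return 1 - torch.sigmoid(beta * ((policy_logps - ref_logps) - kl))

# kto_chosen_loss_sketch(torch.tensor([-10.0]), torch.tensor([-11.0]), torch.tensor(0.2)) ~= 0.48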
1300
+ def get_batch_loss_metrics(
1301
+ self,
1302
+ model,
1303
+ batch: dict[str, Union[list, torch.LongTensor]],
1304
+ ):
1305
+ """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
1306
+ metrics = {}
1307
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1308
+
1309
+ forward_output = self.forward(model, batch)
1310
+ (
1311
+ policy_chosen_logps,
1312
+ policy_rejected_logps,
1313
+ policy_chosen_logits,
1314
+ policy_rejected_logits,
1315
+ policy_KL_logps,
1316
+ ) = forward_output[:5]
1317
+ if self.aux_loss_enabled:
1318
+ aux_loss = forward_output[5]
1319
+
1320
+ # if reference_logps in batch use them, otherwise use the reference model
1321
+ if "reference_logps" in batch:
1322
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1323
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1324
+
1325
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1326
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1327
+ if self.calculate_KL:
1328
+ reference_KL_logps = batch["reference_KL_logps"]
1329
+ else:
1330
+ reference_KL_logps = None
1331
+ else:
1332
+ with torch.no_grad():
1333
+ if self.ref_model is None:
1334
+ with self.null_ref_context():
1335
+ (
1336
+ reference_chosen_logps,
1337
+ reference_rejected_logps,
1338
+ _,
1339
+ _,
1340
+ reference_KL_logps,
1341
+ ) = self.forward(self.model, batch)[:5]
1342
+ else:
1343
+ (
1344
+ reference_chosen_logps,
1345
+ reference_rejected_logps,
1346
+ _,
1347
+ _,
1348
+ reference_KL_logps,
1349
+ ) = self.forward(self.ref_model, batch)[:5]
1350
+
1351
+ losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
1352
+ policy_chosen_logps,
1353
+ policy_rejected_logps,
1354
+ policy_KL_logps,
1355
+ reference_chosen_logps,
1356
+ reference_rejected_logps,
1357
+ reference_KL_logps,
1358
+ )
1359
+ metrics["kl"] = kl.item()
1360
+
1361
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1362
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1363
+
1364
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1365
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1366
+
1367
+ if all_num_chosen > 0:
1368
+ metrics["rewards/chosen_sum"] = (
1369
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1370
+ )
1371
+ metrics["logps/chosen_sum"] = (
1372
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1373
+ )
1374
+ metrics["logits/chosen_sum"] = (
1375
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1376
+ )
1377
+ metrics["count/chosen"] = all_num_chosen
1378
+
1379
+ if all_num_rejected > 0:
1380
+ metrics["rewards/rejected_sum"] = (
1381
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1382
+ )
1383
+ metrics["logps/rejected_sum"] = (
1384
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1385
+ )
1386
+ metrics["logits/rejected_sum"] = (
1387
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1388
+ )
1389
+ metrics["count/rejected"] = all_num_rejected
1390
+
1391
+ loss = losses.nanmean()
1392
+ if self.aux_loss_enabled:
1393
+ loss += self.aux_loss_coef * aux_loss
1394
+
1395
+ return loss, metrics
1396
+
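(Editorial note: the `reference_logps` branch above lets callers precompute reference log-probabilities once (via `precompute_ref_log_probs` in `KTOConfig`), trading one extra pass over the dataset for skipping the reference forward on every training step.)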
1397
+ def compute_loss(
1398
+ self,
1399
+ model: Union[PreTrainedModel, nn.Module],
1400
+ inputs: dict[str, Union[torch.Tensor, Any]],
1401
+ return_outputs=False,
1402
+ num_items_in_batch=None,
1403
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1404
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1405
+
1406
+ with compute_loss_context_manager:
1407
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1408
+
1409
+ # Make sure to move the loss to the same device as the accumulated loss in the `Trainer` class:
1410
+ loss = loss.to(self.args.device)
1411
+ # force log the metrics
1412
+ if self.accelerator.is_main_process:
1413
+ self.store_metrics(metrics, train_eval="train")
1414
+
1415
+ if return_outputs:
1416
+ return (loss, metrics)
1417
+ return loss
1418
+
1419
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1420
+ for key, value in metrics.items():
1421
+ self._stored_metrics[train_eval][key].append(value)
1422
+
1423
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1424
+ if self.train_dataset is None or not has_length(self.train_dataset):
1425
+ return None
1426
+ return SequentialSampler(self.train_dataset)
1427
+
1428
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1429
+ """Generate samples from the model and reference model for the given batch of inputs."""
1430
+
1431
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1432
+ # the torch cuda amp context manager, as some hidden states are silently cast to full precision.
1433
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1434
+
1435
+ with generate_context_manager:
1436
+ policy_output = model.generate(
1437
+ input_ids=batch["prompt_input_ids"],
1438
+ attention_mask=batch["prompt_attention_mask"],
1439
+ max_length=self.max_length,
1440
+ do_sample=True,
1441
+ pad_token_id=self.processing_class.pad_token_id,
1442
+ )
1443
+
1444
+ # if reference_output in batch use that, otherwise use the reference model
1445
+ if "reference_output" in batch:
1446
+ reference_output = batch["reference_output"]
1447
+ else:
1448
+ if self.ref_model is None:
1449
+ with self.null_ref_context():
1450
+ reference_output = self.model.generate(
1451
+ input_ids=batch["prompt_input_ids"],
1452
+ attention_mask=batch["prompt_attention_mask"],
1453
+ max_length=self.max_length,
1454
+ do_sample=True,
1455
+ pad_token_id=self.processing_class.pad_token_id,
1456
+ )
1457
+ else:
1458
+ reference_output = self.ref_model.generate(
1459
+ input_ids=batch["prompt_input_ids"],
1460
+ attention_mask=batch["prompt_attention_mask"],
1461
+ max_length=self.max_length,
1462
+ do_sample=True,
1463
+ pad_token_id=self.processing_class.pad_token_id,
1464
+ )
1465
+
1466
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1467
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1468
+
1469
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1470
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1471
+
1472
+ return policy_output_decoded, reference_output_decoded
1473
+
1474
+ def prediction_step(
1475
+ self,
1476
+ model: Union[PreTrainedModel, nn.Module],
1477
+ inputs: dict[str, Union[torch.Tensor, Any]],
1478
+ prediction_loss_only: bool,
1479
+ ignore_keys: Optional[list[str]] = None,
1480
+ ):
1481
+ if ignore_keys is None:
1482
+ if hasattr(model, "config"):
1483
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1484
+ else:
1485
+ ignore_keys = []
1486
+
1487
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1488
+ with torch.no_grad(), prediction_context_manager:
1489
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1490
+
1491
+ # force log the metrics
1492
+ if self.accelerator.is_main_process:
1493
+ self.store_metrics(metrics, train_eval="eval")
1494
+
1495
+ if prediction_loss_only:
1496
+ return (loss.detach(), None, None)
1497
+
1498
+ # logits for the chosen and rejected samples from model
1499
+ logits_dict = {
1500
+ "eval_logits/chosen": metrics["logits/chosen"],
1501
+ "eval_logits/rejected": metrics["logits/rejected"],
1502
+ }
1503
+ logits = torch.tensor(
1504
+ [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
1505
+ )
1506
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1507
+
1508
+ return (loss.detach(), logits, labels)
1509
+
1510
+ def evaluation_loop(
1511
+ self,
1512
+ dataloader: DataLoader,
1513
+ description: str,
1514
+ prediction_loss_only: Optional[bool] = None,
1515
+ ignore_keys: Optional[list[str]] = None,
1516
+ metric_key_prefix: str = "eval",
1517
+ ) -> EvalLoopOutput:
1518
+ """
1519
+ Overriding built-in evaluation loop to store metrics for each batch.
1520
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1521
+
1522
+ Works both with or without labels.
1523
+ """
1524
+
1525
+ # Sample and save to game log if requested (for one batch to save time)
1526
+ if self.generate_during_eval:
1527
+ # Generate random indices within the range of the total number of samples
1528
+ num_samples = len(dataloader.dataset)
1529
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1530
+
1531
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1532
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1533
+ random_batch = self.data_collator(random_batch_dataset)
1534
+ random_batch = self._prepare_inputs(random_batch)
1535
+
1536
+ target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
+ target_batch = {
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
+ "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
+ }
1542
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1543
+
1544
+ table = pd.DataFrame(
1545
+ columns=["Prompt", "Policy", "Ref Model"],
1546
+ data=[
1547
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1548
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1549
+ ],
1550
+ )
1551
+ if "wandb" in self.args.report_to:
1552
+ wandb.log({"game_log": wandb.Table(data=table)})
1553
+
1554
+ if "comet_ml" in self.args.report_to:
1555
+ log_table_to_comet_experiment(
1556
+ name="game_log.csv",
1557
+ table=table,
1558
+ )
1559
+
1560
+ # Base evaluation
1561
+ initial_output = super().evaluation_loop(
1562
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1563
+ )
1564
+
1565
+ return initial_output
1566
+
1567
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1568
+ """
1569
+ Log `logs` on the various objects watching training, including stored metrics.
1570
+
1571
+ Args:
1572
+ logs (`dict[str, float]`):
1573
+ The values to log.
1574
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1575
+ Start time of the training.
1576
+ """
1577
+ # logs either has 'loss' or 'eval_loss'
1578
+ train_eval = "train" if "loss" in logs else "eval"
1579
+ # train metrics should have no prefix, eval should have 'eval_'
1580
+ prefix = "eval_" if train_eval == "eval" else ""
1581
+ # accumulate average metrics from sums and lengths
1582
+ for split in ["chosen", "rejected"]:
1583
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1584
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1585
+ for metric in ["rewards", "logps", "logits"]:
1586
+ logs[f"{prefix}{metric}/{split}"] = (
1587
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1588
+ / count_sum
1589
+ )
1590
+ # delete obsolete metric
1591
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1592
+ del self._stored_metrics[train_eval][f"count/{split}"]
1593
+ # calculate reward margin
1594
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1595
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1596
+ # Add averaged stored metrics to logs
1597
+ for key, metrics in self._stored_metrics[train_eval].items():
1598
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1599
+ del self._stored_metrics[train_eval]
1600
+
1601
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1602
+ return super().log(logs, start_time)
1603
+ else: # transformers<=4.46
1604
+ return super().log(logs)
1605
+
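(Editorial aside, with made-up numbers: summing rewards and counts before dividing yields a per-example mean rather than a mean of per-batch means.)

sums, counts = [4.0, 2.0], [8, 4]          # two logged steps of rewards/chosen_sum and count/chosen
rewards_chosen = sum(sums) / sum(counts)   # (4.0 + 2.0) / (8 + 4) = 0.5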
1606
+ def create_model_card(
1607
+ self,
1608
+ model_name: Optional[str] = None,
1609
+ dataset_name: Optional[str] = None,
1610
+ tags: Union[str, list[str], None] = None,
1611
+ ):
1612
+ """
1613
+ Creates a draft of a model card using the information available to the `Trainer`.
1614
+
1615
+ Args:
1616
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1617
+ Name of the model.
1618
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1619
+ Name of the dataset used for training.
1620
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1621
+ Tags to be associated with the model card.
1622
+ """
1623
+ if not self.is_world_process_zero():
1624
+ return
1625
+
1626
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1627
+ base_model = self.model.config._name_or_path
1628
+ else:
1629
+ base_model = None
1630
+
1631
+ tags = tags or []
1632
+ if isinstance(tags, str):
1633
+ tags = [tags]
1634
+
1635
+ if hasattr(self.model.config, "unsloth_version"):
1636
+ tags.append("unsloth")
1637
+
1638
+ citation = textwrap.dedent("""\
1639
+ @article{ethayarajh2024kto,
1640
+ title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
1641
+ author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
1642
+ year = 2024,
1643
+ eprint = {arXiv:2402.01306},
1644
+ }""")
1645
+
1646
+ model_card = generate_model_card(
1647
+ base_model=base_model,
1648
+ model_name=model_name,
1649
+ hub_model_id=self.hub_model_id,
1650
+ dataset_name=dataset_name,
1651
+ tags=tags,
1652
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1653
+ comet_url=get_comet_experiment_url(),
1654
+ trainer_name="KTO",
1655
+ trainer_citation=citation,
1656
+ paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
1657
+ paper_id="2402.01306",
1658
+ )
1659
+
1660
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1661
+ class UnslothKTOTrainer(_UnslothKTOTrainer):
1662
+ """
1663
+
1664
+ Initialize KTOTrainer.
1665
+
1666
+ Args:
1667
+ model (`transformers.PreTrainedModel`):
+ The model to train, preferably an `AutoModelForCausalLM`.
+ ref_model (`PreTrainedModelWrapper`):
+ Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1672
+ args (`KTOConfig`):
1673
+ The arguments to use for training.
1674
+ train_dataset (`datasets.Dataset`):
1675
+ The dataset to use for training.
1676
+ eval_dataset (`datasets.Dataset`):
1677
+ The dataset to use for evaluation.
1678
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1679
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1680
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1681
+ reuse the fine-tuned model.
1682
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1683
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1684
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1685
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1686
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1687
+ callbacks (`list[transformers.TrainerCallback]`):
1688
+ The callbacks to use for training.
1689
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1690
+ The optimizer and scheduler to use for training.
1691
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1692
+ The function to use to preprocess the logits before computing the metrics.
1693
+ peft_config (`dict`, defaults to `None`):
1694
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1695
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1696
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1697
+ a dictionary string to metric values.
1698
+ model_adapter_name (`str`, defaults to `None`):
1699
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1700
+ ref_adapter_name (`str`, defaults to `None`):
1701
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1702
+
1703
+ """
1704
+ def __init__(
1705
+ self,
1706
+ model = None,
1707
+ ref_model = None,
1708
+ args = None,
1709
+ train_dataset = None,
1710
+ eval_dataset = None,
1711
+ processing_class = None,
1712
+ data_collator = None,
1713
+ model_init = None,
1714
+ callbacks = None,
1715
+ preprocess_logits_for_metrics = None,
1716
+ peft_config = None,
1717
+ compute_metrics = None,
1718
+ model_adapter_name = None,
1719
+ ref_adapter_name = None,
1720
+ **kwargs
1721
+ ):
1722
+ if args is None: args = UnslothKTOConfig()
1723
+ use_bf16 = getattr(args, 'bf16', False)
1724
+ use_fp16 = getattr(args, 'fp16', False)
1725
+ force_float32 = False
1726
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1727
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1728
+ force_float32 = True
1729
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1730
+ dtype = getattr(model.config, 'torch_dtype', None)
1731
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1732
+ from unsloth_zoo.utils import _get_dtype
1733
+ dtype = _get_dtype(dtype)
1734
+ float16 = dtype == torch.float16
1735
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1736
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1737
+ if force_float32:
1738
+ args.fp16 = False
1739
+ args.bf16 = False
1740
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1741
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1742
+ args.fp16 = float16
1743
+ args.bf16 = not float16
1744
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1745
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1746
+ args.eval_strategy = 'steps'
1747
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1748
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1749
+ if ga_steps is not None and ga_steps > 1:
1750
+ from transformers import __version__ as transformers_version
1751
+ if Version(transformers_version) <= Version('4.45.2'):
1752
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1753
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1754
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1755
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1756
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1757
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1758
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1759
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1760
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1761
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1762
+ if force_float32:
1763
+ args.bf16_full_eval = False
1764
+ args.fp16_full_eval = False
1765
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1766
+ args.bf16_full_eval = True
1767
+ args.fp16_full_eval = False
1768
+ elif not bf16_full_eval and not fp16_full_eval:
1769
+ args.bf16_full_eval = args.bf16
1770
+ args.fp16_full_eval = args.fp16
1771
+ _output_logits = False
1772
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1773
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1774
+ if _output_logits:
1775
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1776
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1777
+ pass
1778
+ else:
1779
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1780
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1781
+ if args_max_seq_length is None and model_max_seq_length is not None:
1782
+ max_seq_length = model.max_seq_length
1783
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1784
+ if model is not None and hasattr(model, 'for_training'):
1785
+ model.for_training()
1786
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1787
+ if 'processing_class' in locals():
1788
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1789
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1790
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1791
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1792
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1793
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1794
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1795
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1796
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1797
+ else:
1798
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1799
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1800
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1801
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1802
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1803
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1804
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1805
+ else:
1806
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1807
+ other_metrics = []
1808
+
1809
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1810
+ PatchRLStatistics('kto_trainer', other_metrics)
1811
+
1812
+ super().__init__(
1813
+ model = model,
1814
+ ref_model = ref_model,
1815
+ args = args,
1816
+ train_dataset = train_dataset,
1817
+ eval_dataset = eval_dataset,
1818
+ processing_class = processing_class,
1819
+ data_collator = data_collator,
1820
+ model_init = model_init,
1821
+ callbacks = callbacks,
1822
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1823
+ peft_config = peft_config,
1824
+ compute_metrics = compute_metrics,
1825
+ model_adapter_name = model_adapter_name,
1826
+ ref_adapter_name = ref_adapter_name,**kwargs)
1827
+ if hasattr(self, 'neftune_hook_handle'):
1828
+ self.neftune_hook_handle.remove()
1829
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1830
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1831
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1832
+ pass
1833
+
1834
+ pass
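(Hypothetical usage sketch of the class above; the model, tokenizer, and dataset names are placeholders, not part of this commit:)

# trainer = UnslothKTOTrainer(
#     model = model,                      # a causal LM prepared for training
#     args = UnslothKTOConfig(output_dir = "kto-out", per_device_train_batch_size = 2),
#     train_dataset = train_ds,           # unpaired examples: "prompt", "completion", "label"
#     processing_class = tokenizer,
# )
# trainer.train()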
unsloth_compiled_cache/UnslothNashMDTrainer.py ADDED
@@ -0,0 +1,949 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
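(Editorial sanity check, shapes illustrative: selective_log_softmax matches a full log_softmax followed by a gather, while avoiding the (batch, seq, vocab) intermediate.)

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 32)                    # (batch, seq, vocab)
index = torch.randint(0, 32, (2, 5))              # token ids to score
reference = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
# selective_log_softmax(logits, index) should equal `reference` up to float tolerance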
+ @dataclass
43
+ class UnslothNashMDConfig(NashMDConfig):
44
+ """
45
+
46
+ Configuration class for the [`NashMDTrainer`].
47
+
48
+ Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
49
+
50
+ Parameters:
51
+ mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
52
+ Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
53
+ mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
54
+ epochs.
55
+
56
+ """
57
+ vllm_sampling_params: Optional[Any] = field(
58
+ default = None,
59
+ metadata = {'help': 'vLLM SamplingParams'},
60
+ )
61
+ unsloth_num_chunks : Optional[int] = field(
62
+ default = -1,
63
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
64
+ )
65
+ def __init__(
66
+ self,
67
+ output_dir = None,
68
+ overwrite_output_dir = None,
69
+ do_train = False,
70
+ do_eval = False,
71
+ do_predict = False,
72
+ eval_strategy = 'no',
73
+ prediction_loss_only = False,
74
+ per_device_train_batch_size = 4,
75
+ per_device_eval_batch_size = 4,
76
+ per_gpu_train_batch_size = None,
77
+ per_gpu_eval_batch_size = None,
78
+ gradient_accumulation_steps = 2,
79
+ eval_accumulation_steps = 2,
80
+ eval_delay = 0,
81
+ torch_empty_cache_steps = 250,
82
+ learning_rate = 5e-05,
83
+ weight_decay = 0.01,
84
+ adam_beta1 = 0.9,
85
+ adam_beta2 = 0.999,
86
+ adam_epsilon = 1e-08,
87
+ max_grad_norm = 1.0,
88
+ num_train_epochs = 3.0,
89
+ max_steps = -1,
90
+ lr_scheduler_type = 'linear',
91
+ warmup_ratio = 0.1,
92
+ warmup_steps = 0,
93
+ log_level = 'passive',
94
+ log_level_replica = 'warning',
95
+ log_on_each_node = True,
96
+ logging_dir = None,
97
+ logging_strategy = 'steps',
98
+ logging_first_step = False,
99
+ logging_steps = 1,
100
+ logging_nan_inf_filter = False,
101
+ save_strategy = 'steps',
102
+ save_steps = 500,
103
+ save_total_limit = None,
104
+ save_safetensors = True,
105
+ save_on_each_node = False,
106
+ save_only_model = False,
107
+ restore_callback_states_from_checkpoint = False,
108
+ no_cuda = False,
109
+ use_cpu = False,
110
+ use_mps_device = False,
111
+ seed = 3407,
112
+ data_seed = 3407,
113
+ jit_mode_eval = False,
114
+ use_ipex = False,
115
+ bf16 = False,
116
+ fp16 = False,
117
+ fp16_opt_level = 'O1',
118
+ half_precision_backend = 'auto',
119
+ bf16_full_eval = False,
120
+ fp16_full_eval = False,
121
+ tf32 = None,
122
+ local_rank = -1,
123
+ ddp_backend = None,
124
+ tpu_num_cores = None,
125
+ tpu_metrics_debug = False,
126
+ debug = '',
127
+ dataloader_drop_last = False,
128
+ eval_steps = None,
129
+ dataloader_num_workers = 0,
130
+ dataloader_prefetch_factor = None,
131
+ past_index = -1,
132
+ run_name = None,
133
+ disable_tqdm = None,
134
+ remove_unused_columns = True,
135
+ label_names = None,
136
+ load_best_model_at_end = False,
137
+ metric_for_best_model = None,
138
+ greater_is_better = None,
139
+ ignore_data_skip = False,
140
+ fsdp = '',
141
+ fsdp_min_num_params = 0,
142
+ fsdp_config = None,
143
+ tp_size = 0,
144
+ fsdp_transformer_layer_cls_to_wrap = None,
145
+ accelerator_config = None,
146
+ deepspeed = None,
147
+ label_smoothing_factor = 0.0,
148
+ optim = 'adamw_8bit',
149
+ optim_args = None,
150
+ adafactor = False,
151
+ group_by_length = False,
152
+ length_column_name = 'length',
153
+ report_to = None,
154
+ ddp_find_unused_parameters = None,
155
+ ddp_bucket_cap_mb = None,
156
+ ddp_broadcast_buffers = None,
157
+ dataloader_pin_memory = True,
158
+ dataloader_persistent_workers = False,
159
+ skip_memory_metrics = True,
160
+ use_legacy_prediction_loop = False,
161
+ push_to_hub = False,
162
+ resume_from_checkpoint = None,
163
+ hub_model_id = None,
164
+ hub_strategy = 'every_save',
165
+ hub_token = None,
166
+ hub_private_repo = None,
167
+ hub_always_push = False,
168
+ gradient_checkpointing = False,
169
+ gradient_checkpointing_kwargs = None,
170
+ include_inputs_for_metrics = False,
171
+ eval_do_concat_batches = True,
172
+ fp16_backend = 'auto',
173
+ push_to_hub_model_id = None,
174
+ push_to_hub_organization = None,
175
+ push_to_hub_token = None,
176
+ mp_parameters = '',
177
+ auto_find_batch_size = False,
178
+ full_determinism = False,
179
+ torchdynamo = None,
180
+ ray_scope = 'last',
181
+ ddp_timeout = 1800,
182
+ torch_compile = False,
183
+ torch_compile_backend = None,
184
+ torch_compile_mode = None,
185
+ include_tokens_per_second = False,
186
+ include_num_input_tokens_seen = False,
187
+ neftune_noise_alpha = None,
188
+ optim_target_modules = None,
189
+ batch_eval_metrics = False,
190
+ eval_on_start = False,
191
+ use_liger_kernel = False,
192
+ eval_use_gather_object = False,
193
+ average_tokens_across_devices = False,
194
+ reward_model_path = None,
195
+ judge = None,
196
+ max_new_tokens = 64,
197
+ max_length = 512,
198
+ temperature = 0.9,
199
+ missing_eos_penalty = None,
200
+ loss_type = 'sigmoid',
201
+ dataset_num_proc = None,
202
+ disable_dropout = True,
203
+ use_vllm = False,
204
+ ds3_gather_for_generation = True,
205
+ vllm_sampling_params = None,
206
+ unsloth_num_chunks = -1,
207
+ **kwargs,
208
+ ):
209
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
210
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
211
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
212
+ output_dir = 'unsloth_training_checkpoints'
213
+ save_strategy = 'no'
214
+ if dataset_num_proc is None:
215
+ from multiprocessing import cpu_count
216
+ dataset_num_proc = cpu_count()
217
+
218
+ super().__init__(
219
+ output_dir = output_dir,
220
+ overwrite_output_dir = overwrite_output_dir,
221
+ do_train = do_train,
222
+ do_eval = do_eval,
223
+ do_predict = do_predict,
224
+ eval_strategy = eval_strategy,
225
+ prediction_loss_only = prediction_loss_only,
226
+ per_device_train_batch_size = per_device_train_batch_size,
227
+ per_device_eval_batch_size = per_device_eval_batch_size,
228
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
229
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
230
+ gradient_accumulation_steps = gradient_accumulation_steps,
231
+ eval_accumulation_steps = eval_accumulation_steps,
232
+ eval_delay = eval_delay,
233
+ torch_empty_cache_steps = torch_empty_cache_steps,
234
+ learning_rate = learning_rate,
235
+ weight_decay = weight_decay,
236
+ adam_beta1 = adam_beta1,
237
+ adam_beta2 = adam_beta2,
238
+ adam_epsilon = adam_epsilon,
239
+ max_grad_norm = max_grad_norm,
240
+ num_train_epochs = num_train_epochs,
241
+ max_steps = max_steps,
242
+ lr_scheduler_type = lr_scheduler_type,
243
+ warmup_ratio = warmup_ratio,
244
+ warmup_steps = warmup_steps,
245
+ log_level = log_level,
246
+ log_level_replica = log_level_replica,
247
+ log_on_each_node = log_on_each_node,
248
+ logging_dir = logging_dir,
249
+ logging_strategy = logging_strategy,
250
+ logging_first_step = logging_first_step,
251
+ logging_steps = logging_steps,
252
+ logging_nan_inf_filter = logging_nan_inf_filter,
253
+ save_strategy = save_strategy,
254
+ save_steps = save_steps,
255
+ save_total_limit = save_total_limit,
256
+ save_safetensors = save_safetensors,
257
+ save_on_each_node = save_on_each_node,
258
+ save_only_model = save_only_model,
259
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
260
+ no_cuda = no_cuda,
261
+ use_cpu = use_cpu,
262
+ use_mps_device = use_mps_device,
263
+ seed = seed,
264
+ data_seed = data_seed,
265
+ jit_mode_eval = jit_mode_eval,
266
+ use_ipex = use_ipex,
267
+ bf16 = bf16,
268
+ fp16 = fp16,
269
+ fp16_opt_level = fp16_opt_level,
270
+ half_precision_backend = half_precision_backend,
271
+ bf16_full_eval = bf16_full_eval,
272
+ fp16_full_eval = fp16_full_eval,
273
+ tf32 = tf32,
274
+ local_rank = local_rank,
275
+ ddp_backend = ddp_backend,
276
+ tpu_num_cores = tpu_num_cores,
277
+ tpu_metrics_debug = tpu_metrics_debug,
278
+ debug = debug,
279
+ dataloader_drop_last = dataloader_drop_last,
280
+ eval_steps = eval_steps,
281
+ dataloader_num_workers = dataloader_num_workers,
282
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
283
+ past_index = past_index,
284
+ run_name = run_name,
285
+ disable_tqdm = disable_tqdm,
286
+ remove_unused_columns = remove_unused_columns,
287
+ label_names = label_names,
288
+ load_best_model_at_end = load_best_model_at_end,
289
+ metric_for_best_model = metric_for_best_model,
290
+ greater_is_better = greater_is_better,
291
+ ignore_data_skip = ignore_data_skip,
292
+ fsdp = fsdp,
293
+ fsdp_min_num_params = fsdp_min_num_params,
294
+ fsdp_config = fsdp_config,
295
+ tp_size = tp_size,
296
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
297
+ accelerator_config = accelerator_config,
298
+ deepspeed = deepspeed,
299
+ label_smoothing_factor = label_smoothing_factor,
300
+ optim = optim,
301
+ optim_args = optim_args,
302
+ adafactor = adafactor,
303
+ group_by_length = group_by_length,
304
+ length_column_name = length_column_name,
305
+ report_to = report_to,
306
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
307
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
308
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
309
+ dataloader_pin_memory = dataloader_pin_memory,
310
+ dataloader_persistent_workers = dataloader_persistent_workers,
311
+ skip_memory_metrics = skip_memory_metrics,
312
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
313
+ push_to_hub = push_to_hub,
314
+ resume_from_checkpoint = resume_from_checkpoint,
315
+ hub_model_id = hub_model_id,
316
+ hub_strategy = hub_strategy,
317
+ hub_token = hub_token,
318
+ hub_private_repo = hub_private_repo,
319
+ hub_always_push = hub_always_push,
320
+ gradient_checkpointing = gradient_checkpointing,
321
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
322
+ include_inputs_for_metrics = include_inputs_for_metrics,
323
+ eval_do_concat_batches = eval_do_concat_batches,
324
+ fp16_backend = fp16_backend,
325
+ push_to_hub_model_id = push_to_hub_model_id,
326
+ push_to_hub_organization = push_to_hub_organization,
327
+ push_to_hub_token = push_to_hub_token,
328
+ mp_parameters = mp_parameters,
329
+ auto_find_batch_size = auto_find_batch_size,
330
+ full_determinism = full_determinism,
331
+ torchdynamo = torchdynamo,
332
+ ray_scope = ray_scope,
333
+ ddp_timeout = ddp_timeout,
334
+ torch_compile = torch_compile,
335
+ torch_compile_backend = torch_compile_backend,
336
+ torch_compile_mode = torch_compile_mode,
337
+ include_tokens_per_second = include_tokens_per_second,
338
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
339
+ neftune_noise_alpha = neftune_noise_alpha,
340
+ optim_target_modules = optim_target_modules,
341
+ batch_eval_metrics = batch_eval_metrics,
342
+ eval_on_start = eval_on_start,
343
+ use_liger_kernel = use_liger_kernel,
344
+ eval_use_gather_object = eval_use_gather_object,
345
+ average_tokens_across_devices = average_tokens_across_devices,
346
+ reward_model_path = reward_model_path,
347
+ judge = judge,
348
+ max_new_tokens = max_new_tokens,
349
+ max_length = max_length,
350
+ temperature = temperature,
351
+ missing_eos_penalty = missing_eos_penalty,
352
+ loss_type = loss_type,
353
+ dataset_num_proc = dataset_num_proc,
354
+ disable_dropout = disable_dropout,
355
+ use_vllm = use_vllm,
356
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
357
+ self.vllm_sampling_params = vllm_sampling_params
358
+ self.unsloth_num_chunks = unsloth_num_chunks
359
+ pass
360
+
361
+ class _UnslothNashMDTrainer(OnlineDPOTrainer):
362
+ r""""""
363
+
364
+ _tag_names = ["trl", "nash-md"]
365
+
366
+ def __init__(
367
+ self,
368
+ model: Union[PreTrainedModel, nn.Module] = None,
369
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
370
+ reward_model: Union[PreTrainedModel, nn.Module, None] = None,
371
+ judge: Optional[BasePairwiseJudge] = None,
372
+ args: Optional[NashMDConfig] = None,
373
+ data_collator: Optional[Callable] = None,
374
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
375
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
376
+ processing_class: Optional[
377
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
378
+ ] = None,
379
+ peft_config: Optional[dict] = None,
380
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
381
+ callbacks: Optional[list[TrainerCallback]] = None,
382
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
383
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
384
+ ) -> None:
385
+ super().__init__(
386
+ model=model,
387
+ ref_model=ref_model,
388
+ reward_model=reward_model,
389
+ judge=judge,
390
+ args=args,
391
+ data_collator=data_collator,
392
+ train_dataset=train_dataset,
393
+ eval_dataset=eval_dataset,
394
+ processing_class=processing_class,
395
+ reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
396
+ peft_config=peft_config,
397
+ compute_metrics=compute_metrics,
398
+ callbacks=callbacks,
399
+ optimizers=optimizers,
400
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
401
+ )
402
+
403
+ self._mixture_coef = self.args.mixture_coef
404
+
405
+ # Overwrite the stats dictionary to include NashMD specific statistics
406
+ self.stats = {
407
+ # Remove "non_score_reward", "rlhf_reward", "scores_margin"
408
+ # Add "mixture_coef"
409
+ "loss/kl": [],
410
+ "objective/entropy": [],
411
+ "loss/score": [],
412
+ "rewards/probabilities": [],
413
+ "rewards/accuracies": [],
414
+ "rewards/margins": [],
415
+ "logps/chosen": [],
416
+ "logps/rejected": [],
417
+ "val/model_contain_eos_token": [],
418
+ "val/ref_contain_eos_token": [],
419
+ "beta": [],
420
+ "mixture_coef": [],
421
+ }
422
+ if self.reward_model is not None:
423
+ self.stats["rewards/chosen"] = []
424
+ self.stats["rewards/rejected"] = []
425
+
426
+ @property
427
+ def mixture_coef(self):
428
+ if isinstance(self._mixture_coef, list):
429
+ epoch = self.state.epoch
430
+ return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
431
+ else:
432
+ return self._mixture_coef
433
+
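(Editorial illustration of the schedule implemented above, with hypothetical values:)

schedule = [0.5, 0.3, 0.1]                 # mixture_coef given as a list
coef = lambda epoch: schedule[epoch] if epoch < len(schedule) else schedule[-1]
assert coef(0) == 0.5 and coef(5) == 0.1   # later epochs keep the last coefficient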
434
+ def _generate_completions(self, model, prompts):
435
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
436
+ model_output = unwrapped_model.generate(
437
+ input_ids=prompts["input_ids"],
438
+ attention_mask=prompts["attention_mask"],
439
+ generation_config=self.generation_config,
440
+ )
441
+
442
+ ref_model = model if self.ref_model is None else self.ref_model
443
+ with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
444
+ mixture_model = GeometricMixtureWrapper(
445
+ model=unwrapped_model,
446
+ ref_model=unwrapped_ref_model,
447
+ generation_config=self.generation_config,
448
+ mixture_coef=self.mixture_coef,
449
+ device=self.accelerator.device,
450
+ )
451
+
452
+ mixture_output = mixture_model.generate(
453
+ input_ids=prompts["input_ids"],
454
+ attention_mask=prompts["attention_mask"],
455
+ generation_config=self.generation_config,
456
+ )
457
+
458
+ return model_output, mixture_output
459
+
460
+ def _process_completions(self, model_output, mixture_output, prompts):
461
+ context_length = prompts["input_ids"].shape[1]
462
+
463
+ # Process model completions
464
+ model_completion_ids = model_output[:, context_length:]
465
+ model_completion_ids, model_completion_mask = truncate_right(
466
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
467
+ )
468
+ model_data = {
469
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
470
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
471
+ "raw": prompts["raw"],
472
+ }
473
+
474
+ # Process reference model completions
475
+ mixture_completion_ids = mixture_output[:, context_length:]
476
+ mixture_completion_ids, mixture_completion_mask = truncate_right(
477
+ mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
478
+ )
479
+ mixture_data = {
480
+ "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
481
+ "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
482
+ "raw": prompts["raw"],
483
+ }
484
+
485
+ return model_data, mixture_data
486
+
487
+ def _compute_rewards(self, model_data, mixture_data, context_length):
488
+ with torch.no_grad():
489
+ _, model_scores, _ = get_reward(
490
+ self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
491
+ )
492
+ _, mixture_scores, _ = get_reward(
493
+ self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
494
+ )
495
+
496
+ # Apply EOS penalty if needed
497
+ if self.args.missing_eos_penalty is not None:
498
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
499
+ mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
500
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
501
+ mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
502
+
503
+ return model_scores, mixture_scores
504
+
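(Editorial note with illustrative numbers: with missing_eos_penalty=1.0, a completion scored 0.7 that never emits the EOS token is rescored to -0.3, discouraging truncated generations.)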
505
+ def _compute_judge(self, model_data, mixture_data, context_length):
506
+ prompts = model_data["raw"]
507
+ model_data_completions = self.processing_class.batch_decode(
508
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
509
+ )
510
+ model_data_completions = [completion.strip() for completion in model_data_completions]
511
+
512
+ mixture_data_completions = self.processing_class.batch_decode(
513
+ mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
514
+ )
515
+ mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
516
+ if is_conversational({"prompt": prompts[0]}):
517
+ model_data_completions = [
518
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
519
+ ]
520
+ environment = jinja2.Environment()
521
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
522
+ prompts = [template.render(messages=message) for message in prompts]
523
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
524
+
525
+ mixture_data_completions = [
526
+ [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
527
+ ]
528
+ mixture_data_completions = [
529
+ template.render(messages=completion) for completion in mixture_data_completions
530
+ ]
531
+
532
+ probability = self.judge.judge(
533
+ prompts,
534
+ list(zip(model_data_completions, mixture_data_completions)),
535
+ return_scores=True,
536
+ )
537
+ return torch.tensor(probability, device=model_data["input_ids"].device)
538
+
539
+ def _compute_logprobs(self, model, model_data, context_length):
540
+ def compute_logprobs_for_data(m, data):
541
+ output = m(data["input_ids"], attention_mask=data["attention_mask"])
542
+ logits = output.logits[:, context_length - 1 : -1]
543
+ token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
544
+ return token_logprobs
545
+
546
+ # Compute logprobs for model completions under the model
547
+ model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
548
+
549
+ # Compute logprobs of model completions under the reference model
550
+ with torch.no_grad():
551
+ if self.ref_model is None:
552
+ with model.disable_adapter():
553
+ ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
554
+ else:
555
+ ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
556
+
557
+ # Mask padding tokens
558
+ model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
559
+ model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
560
+ ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
561
+
562
+ return (model_logprobs_model_data, ref_logprobs_model_data)
563
+
564
+ def _compute_losses(
565
+ self,
566
+ model_logprobs_model_data,
567
+ ref_logprobs_model_data,
568
+ probability,
569
+ ):
570
+ # reinforce score where 0.5 is a control variate
571
+ score = (probability - 0.5) * model_logprobs_model_data.sum(1)
572
+
573
+ # kl divergence via reinforce
574
+ with torch.no_grad():
575
+ log_ratio = model_logprobs_model_data - ref_logprobs_model_data
576
+ kl_div_log = log_ratio.sum(1)
577
+ kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
578
+
579
+ # final loss
580
+ loss = self.beta * kl_div_loss - score
581
+
582
+ return loss.mean(), score, kl_div_log
583
+
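(Hedged numeric sketch of the loss just computed; tensor values are illustrative and masking/gradient details are simplified:)

import torch

beta = 0.1
model_lp = torch.tensor([[-1.0, -2.0]])   # policy logprobs of its own completion tokens
ref_lp = torch.tensor([[-1.5, -2.5]])     # reference logprobs of the same tokens
prob = torch.tensor([0.8])                # judged P(model completion beats mixture)

score = (prob - 0.5) * model_lp.sum(1)                # REINFORCE score with 0.5 control variate
kl_loss = ((model_lp - ref_lp) * model_lp).sum(1)     # REINFORCE estimate of the KL penalty
loss = (beta * kl_loss - score).mean()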
584
+ def _log_statistics(
585
+ self,
586
+ model_data,
587
+ mixture_data,
588
+ model_logprobs_model_data,
589
+ ref_logprobs_model_data,
590
+ probability,
591
+ score,
592
+ kl_div,
593
+ context_length,
594
+ model_scores=None,
595
+ mixture_scores=None,
596
+ ):
597
+ # Helper function to gather and compute mean
598
+ def gather_mean(tensor):
599
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
600
+
601
+ # Log score
602
+ self.stats["loss/score"].append(gather_mean(score))
603
+ # Log KL divergence
604
+ self.stats["loss/kl"].append(gather_mean(kl_div))
605
+
606
+ # Log logprobs
607
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
608
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
609
+
610
+ self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
611
+ self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
612
+
613
+ # Log rewards
614
+ if self.reward_model is not None:
615
+ self.stats["rewards/chosen"].append(gather_mean(model_scores))
616
+ self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
617
+
618
+ # Log probabilities
619
+ self.stats["rewards/probabilities"].append(gather_mean(probability))
620
+
621
+ # Calculate entropy for model data
622
+ entropy_model_data = -model_logprobs_model_data.sum(1)
623
+ self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
624
+
625
+ # Calculate margins
626
+ margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
627
+ self.stats["rewards/margins"].append(gather_mean(margin))
628
+
629
+ # Calculate accuracy
630
+ accuracy = (margin > 0).float()
631
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy))
632
+
633
+ # Log EOS token statistics
634
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
635
+ mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
636
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
637
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
638
+
639
+ # Log beta and mixture coef
640
+ self.stats["beta"].append(self.beta)
641
+ self.stats["mixture_coef"].append(self.mixture_coef)
642
+
643
+ def training_step(
644
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
645
+ ) -> torch.Tensor:
646
+ model.train()
647
+
648
+ # Apply chat template and tokenize the input
649
+ batch_size = len(next(iter(inputs.values())))
650
+ prompts = inputs["prompt"]
651
+ inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
652
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
653
+ inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
654
+ inputs = self.data_collator(inputs)
655
+
656
+ # only the prompt_* fields are needed from here on
657
+ inputs = self._prepare_inputs(inputs)
658
+ context_length = inputs["prompt_input_ids"].shape[1]
659
+ prompts = {
660
+ "input_ids": inputs["prompt_input_ids"],
661
+ "attention_mask": inputs["prompt_attention_mask"],
662
+ "raw": prompts,
663
+ }
664
+ del inputs
665
+
666
+ # Sample completions from both the model and the reference model
667
+ model_output, mixture_output = self._generate_completions(model, prompts)
668
+
669
+ # Process model completions
670
+ model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
671
+
672
+ # Compute rewards
673
+ if self.reward_model is not None:
674
+ model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
675
+ # probability of the model data vs the mixture data
676
+ probability = F.sigmoid(model_scores - mixture_scores)
677
+ else:
678
+ model_scores, mixture_scores = None, None
679
+ probability = self._compute_judge(model_data, mixture_data, context_length)
680
+
681
+ # Compute logprobs
682
+ model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
683
+
684
+ # Compute loss
685
+ loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
686
+
687
+ # Log everything
688
+ self._log_statistics(
689
+ model_data,
690
+ mixture_data,
691
+ model_logprobs_model_data.detach(),
692
+ ref_logprobs_model_data,
693
+ probability,
694
+ score.detach(),
695
+ kl_div.detach(),
696
+ context_length,
697
+ model_scores,
698
+ mixture_scores,
699
+ )
700
+
701
+ if (
702
+ self.args.torch_empty_cache_steps is not None
703
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
704
+ ):
705
+ empty_cache()
706
+
707
+ kwargs = {}
708
+ # For LOMO optimizers you need to explicitly use the learning rate
709
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
710
+ kwargs["learning_rate"] = self._get_learning_rate()
711
+
712
+ if self.args.n_gpu > 1:
713
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
714
+
715
+ if self.use_apex:
716
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
717
+ scaled_loss.backward()
718
+ else:
719
+ self.accelerator.backward(loss, **kwargs)
720
+
721
+ return loss.detach() / self.args.gradient_accumulation_steps
722
+
723
+ def create_model_card(
724
+ self,
725
+ model_name: Optional[str] = None,
726
+ dataset_name: Optional[str] = None,
727
+ tags: Union[str, list[str], None] = None,
728
+ ):
729
+ """
730
+ Creates a draft of a model card using the information available to the `Trainer`.
731
+
732
+ Args:
733
+ model_name (`str` or `None`, *optional*, defaults to `None`):
734
+ Name of the model.
735
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
736
+ Name of the dataset used for training.
737
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
738
+ Tags to be associated with the model card.
739
+ """
740
+ if not self.is_world_process_zero():
741
+ return
742
+
743
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
744
+ base_model = self.model.config._name_or_path
745
+ else:
746
+ base_model = None
747
+
748
+ tags = tags or []
749
+ if isinstance(tags, str):
750
+ tags = [tags]
751
+
752
+ if hasattr(self.model.config, "unsloth_version"):
753
+ tags.append("unsloth")
754
+
755
+ citation = textwrap.dedent("""\
756
+ @inproceedings{munos2024nash,
757
+ title = {{Nash Learning from Human Feedback}},
758
+ author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
759
+ year = 2024,
760
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
761
+ publisher = {OpenReview.net},
762
+ url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
763
+ }""")
764
+
765
+ model_card = generate_model_card(
766
+ base_model=base_model,
767
+ model_name=model_name,
768
+ hub_model_id=self.hub_model_id,
769
+ dataset_name=dataset_name,
770
+ tags=tags,
771
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
772
+ comet_url=get_comet_experiment_url(),
773
+ trainer_name="Nash-MD",
774
+ trainer_citation=citation,
775
+ paper_title="Nash Learning from Human Feedback",
776
+ paper_id="2312.00886",
777
+ )
778
+
779
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
780
+ class UnslothNashMDTrainer(_UnslothNashMDTrainer):
781
+ """
782
+
783
+ Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`].
784
+
785
+ Args:
786
+ model (`transformers.PreTrainedModel`):
787
+ The model to train, preferably an `AutoModelForCausalLM`.
788
+ ref_model (`PreTrainedModelWrapper`):
789
+ Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
790
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
791
+ reward_model (`transformers.PreTrainedModel`):
792
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
793
+ judge (`BasePairwiseJudge`):
794
+ The judge to use for pairwise comparison of model completions.
795
+ args (`NashMDConfig`):
796
+ The NashMD config arguments to use for training.
797
+ data_collator (`transformers.DataCollator`):
798
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
799
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
800
+ train_dataset (`datasets.Dataset`):
801
+ The dataset to use for training.
802
+ eval_dataset (`datasets.Dataset`):
803
+ The dataset to use for evaluation.
804
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
805
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
806
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
807
+ reuse the fine-tuned model.
808
+ peft_config (`dict`):
809
+ The peft config to use for training.
810
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
811
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
812
+ a dictionary string to metric values.
813
+ callbacks (`list[transformers.TrainerCallback]`):
814
+ The callbacks to use for training.
815
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
816
+ The optimizer and scheduler to use for training.
817
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
818
+ The function to use to preprocess the logits before computing the metrics.
819
+
820
+ """
821
+ def __init__(
822
+ self,
823
+ model = None,
824
+ ref_model = None,
825
+ reward_model = None,
826
+ judge = None,
827
+ args = None,
828
+ data_collator = None,
829
+ train_dataset = None,
830
+ eval_dataset = None,
831
+ processing_class = None,
832
+ peft_config = None,
833
+ compute_metrics = None,
834
+ callbacks = None,
835
+ preprocess_logits_for_metrics = None,
836
+ **kwargs
837
+ ):
838
+ if args is None: args = UnslothNashMDConfig()
839
+ use_bf16 = getattr(args, 'bf16', False)
840
+ use_fp16 = getattr(args, 'fp16', False)
841
+ force_float32 = False
842
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
843
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
844
+ force_float32 = True
845
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
846
+ dtype = getattr(model.config, 'torch_dtype', None)
847
+ if dtype is None: dtype = model.get_input_embeddings().dtype
848
+ from unsloth_zoo.utils import _get_dtype
849
+ dtype = _get_dtype(dtype)
850
+ float16 = dtype == torch.float16
851
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
852
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
853
+ if force_float32:
854
+ args.fp16 = False
855
+ args.bf16 = False
856
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
857
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
858
+ args.fp16 = float16
859
+ args.bf16 = not float16
860
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
861
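The dtype checks above boil down to one rule: forced float32 disables mixed precision, and otherwise the Accelerate mode follows the dtype the model was loaded in (fp16 models must not request bf16 and vice versa). A standalone sketch of the same rule; the helper name is hypothetical, not part of the generated code:

import torch

def pick_mixed_precision(model_dtype, force_float32 = False):
    # Mirrors the branch above: the float32 override wins, else follow the model dtype.
    if force_float32:
        return 'no'
    return 'fp16' if model_dtype == torch.float16 else 'bf16'

assert pick_mixed_precision(torch.float16) == 'fp16'
assert pick_mixed_precision(torch.bfloat16) == 'bf16'
assert pick_mixed_precision(torch.float16, force_float32 = True) == 'no'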
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
862
+ args.eval_strategy = 'steps'
863
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
864
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
865
+ if ga_steps is not None and ga_steps > 1:
866
+ from transformers import __version__ as transformers_version
867
+ if Version(transformers_version) <= Version('4.45.2'):
868
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
869
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
870
+ if getattr(args, 'eval_strategy', 'no') != 'no':
871
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
872
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
873
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
874
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
875
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
876
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
877
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
878
+ if force_float32:
879
+ args.bf16_full_eval = False
880
+ args.fp16_full_eval = False
881
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
882
+ args.bf16_full_eval = True
883
+ args.fp16_full_eval = False
884
+ elif not bf16_full_eval and not fp16_full_eval:
885
+ args.bf16_full_eval = args.bf16
886
+ args.fp16_full_eval = args.fp16
887
+ _output_logits = False
888
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
889
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
890
+ if _output_logits:
891
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
892
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
893
+ pass
894
+ else:
895
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
896
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
897
+ if args_max_seq_length is None and model_max_seq_length is not None:
898
+ max_seq_length = model.max_seq_length
899
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
900
+ if model is not None and hasattr(model, 'for_training'):
901
+ model.for_training()
902
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
903
+ if 'processing_class' in locals():
904
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
905
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
906
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
907
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
908
+ if not isinstance(data_collator, UnslothVisionDataCollator):
909
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
910
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
911
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
912
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
913
+ else:
914
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
915
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
916
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
917
+ if not isinstance(data_collator, UnslothVisionDataCollator):
918
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
919
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
920
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
921
+ else:
922
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
923
+ other_metrics = []
924
+
925
+ from unsloth_zoo.logging_utils import PatchRLStatistics
926
+ PatchRLStatistics('nash_md_trainer', other_metrics)
927
+
928
+ super().__init__(
929
+ model = model,
930
+ ref_model = ref_model,
931
+ reward_model = reward_model,
932
+ judge = judge,
933
+ args = args,
934
+ data_collator = data_collator,
935
+ train_dataset = train_dataset,
936
+ eval_dataset = eval_dataset,
937
+ processing_class = processing_class,
938
+ peft_config = peft_config,
939
+ compute_metrics = compute_metrics,
940
+ callbacks = callbacks,
941
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
942
+ if hasattr(self, 'neftune_hook_handle'):
943
+ self.neftune_hook_handle.remove()
944
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
945
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
946
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
947
+ pass
948
+
949
+ pass
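A minimal usage sketch for the wrapper defined above. The model, tokenizer, judge, and dataset are placeholders that must be created elsewhere; the sketch is kept as comments so nothing would execute at import time:

# from trl import PairRMJudge
#
# trainer = UnslothNashMDTrainer(
#     model = model,                                   # causal LM loaded elsewhere
#     judge = PairRMJudge(),                           # or pass reward_model= instead
#     args = UnslothNashMDConfig(output_dir = 'nash-md-out', max_steps = 10),
#     train_dataset = prompt_dataset,                  # dataset with a 'prompt' column
#     processing_class = tokenizer,
# )
# trainer.train()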
unsloth_compiled_cache/UnslothORPOTrainer.py ADDED
@@ -0,0 +1,1537 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
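The gather/logsumexp pattern above uses the identity log_softmax(x)_i = x_i - logsumexp(x). A self-contained check with arbitrary shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)                        # (batch, seq, vocab)
index = torch.randint(0, 11, (2, 5))                  # token ids to score
direct = F.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
identity = logits.float().gather(-1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits.float(), dim = -1)
assert torch.allclose(direct, identity, atol = 1e-5)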
+ @dataclass
43
+ class UnslothORPOConfig(ORPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`ORPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
66
+ it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
67
+ disable_dropout (`bool`, *optional*, defaults to `True`):
68
+ Whether to disable dropout in the model.
69
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
70
+ Label pad token id. This argument is required if you want to use the default data collator.
71
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
72
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
73
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
74
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
75
+ This argument is required if you want to use the default data collator.
76
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
77
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
78
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
79
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
80
+ you need to specify if the model returned by the callable is an encoder-decoder model.
81
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
82
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
83
+ string.
84
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
85
+ Number of processes to use for processing the dataset.
86
+
87
+ """
88
+ vllm_sampling_params: Optional[Any] = field(
89
+ default = None,
90
+ metadata = {'help': 'vLLM SamplingParams'},
91
+ )
92
+ unsloth_num_chunks : Optional[int] = field(
93
+ default = -1,
94
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
95
+ )
96
+ def __init__(
97
+ self,
98
+ output_dir = None,
99
+ overwrite_output_dir = None,
100
+ do_train = False,
101
+ do_eval = False,
102
+ do_predict = False,
103
+ eval_strategy = 'no',
104
+ prediction_loss_only = False,
105
+ per_device_train_batch_size = 4,
106
+ per_device_eval_batch_size = 4,
107
+ per_gpu_train_batch_size = None,
108
+ per_gpu_eval_batch_size = None,
109
+ gradient_accumulation_steps = 2,
110
+ eval_accumulation_steps = 2,
111
+ eval_delay = 0,
112
+ torch_empty_cache_steps = 250,
113
+ learning_rate = 5e-05,
114
+ weight_decay = 0.01,
115
+ adam_beta1 = 0.9,
116
+ adam_beta2 = 0.999,
117
+ adam_epsilon = 1e-08,
118
+ max_grad_norm = 1.0,
119
+ num_train_epochs = 3.0,
120
+ max_steps = -1,
121
+ lr_scheduler_type = 'linear',
122
+ warmup_ratio = 0.1,
123
+ warmup_steps = 0,
124
+ log_level = 'passive',
125
+ log_level_replica = 'warning',
126
+ log_on_each_node = True,
127
+ logging_dir = None,
128
+ logging_strategy = 'steps',
129
+ logging_first_step = False,
130
+ logging_steps = 1,
131
+ logging_nan_inf_filter = False,
132
+ save_strategy = 'steps',
133
+ save_steps = 500,
134
+ save_total_limit = None,
135
+ save_safetensors = True,
136
+ save_on_each_node = False,
137
+ save_only_model = False,
138
+ restore_callback_states_from_checkpoint = False,
139
+ no_cuda = False,
140
+ use_cpu = False,
141
+ use_mps_device = False,
142
+ seed = 3407,
143
+ data_seed = 3407,
144
+ jit_mode_eval = False,
145
+ use_ipex = False,
146
+ bf16 = False,
147
+ fp16 = False,
148
+ fp16_opt_level = 'O1',
149
+ half_precision_backend = 'auto',
150
+ bf16_full_eval = False,
151
+ fp16_full_eval = False,
152
+ tf32 = None,
153
+ local_rank = -1,
154
+ ddp_backend = None,
155
+ tpu_num_cores = None,
156
+ tpu_metrics_debug = False,
157
+ debug = '',
158
+ dataloader_drop_last = False,
159
+ eval_steps = None,
160
+ dataloader_num_workers = 0,
161
+ dataloader_prefetch_factor = None,
162
+ past_index = -1,
163
+ run_name = None,
164
+ disable_tqdm = None,
165
+ remove_unused_columns = True,
166
+ label_names = None,
167
+ load_best_model_at_end = False,
168
+ metric_for_best_model = None,
169
+ greater_is_better = None,
170
+ ignore_data_skip = False,
171
+ fsdp = '',
172
+ fsdp_min_num_params = 0,
173
+ fsdp_config = None,
174
+ tp_size = 0,
175
+ fsdp_transformer_layer_cls_to_wrap = None,
176
+ accelerator_config = None,
177
+ deepspeed = None,
178
+ label_smoothing_factor = 0.0,
179
+ optim = 'adamw_8bit',
180
+ optim_args = None,
181
+ adafactor = False,
182
+ group_by_length = False,
183
+ length_column_name = 'length',
184
+ report_to = None,
185
+ ddp_find_unused_parameters = None,
186
+ ddp_bucket_cap_mb = None,
187
+ ddp_broadcast_buffers = None,
188
+ dataloader_pin_memory = True,
189
+ dataloader_persistent_workers = False,
190
+ skip_memory_metrics = True,
191
+ use_legacy_prediction_loop = False,
192
+ push_to_hub = False,
193
+ resume_from_checkpoint = None,
194
+ hub_model_id = None,
195
+ hub_strategy = 'every_save',
196
+ hub_token = None,
197
+ hub_private_repo = None,
198
+ hub_always_push = False,
199
+ gradient_checkpointing = False,
200
+ gradient_checkpointing_kwargs = None,
201
+ include_inputs_for_metrics = False,
202
+ eval_do_concat_batches = True,
203
+ fp16_backend = 'auto',
204
+ push_to_hub_model_id = None,
205
+ push_to_hub_organization = None,
206
+ push_to_hub_token = None,
207
+ mp_parameters = '',
208
+ auto_find_batch_size = False,
209
+ full_determinism = False,
210
+ torchdynamo = None,
211
+ ray_scope = 'last',
212
+ ddp_timeout = 1800,
213
+ torch_compile = False,
214
+ torch_compile_backend = None,
215
+ torch_compile_mode = None,
216
+ include_tokens_per_second = False,
217
+ include_num_input_tokens_seen = False,
218
+ neftune_noise_alpha = None,
219
+ optim_target_modules = None,
220
+ batch_eval_metrics = False,
221
+ eval_on_start = False,
222
+ use_liger_kernel = False,
223
+ eval_use_gather_object = False,
224
+ average_tokens_across_devices = False,
225
+ max_length = 1024,
226
+ max_prompt_length = 512,
227
+ max_completion_length = None,
228
+ beta = 0.1,
229
+ disable_dropout = True,
230
+ label_pad_token_id = -100,
231
+ padding_value = None,
232
+ truncation_mode = 'keep_end',
233
+ generate_during_eval = False,
234
+ is_encoder_decoder = None,
235
+ model_init_kwargs = None,
236
+ dataset_num_proc = None,
237
+ vllm_sampling_params = None,
238
+ unsloth_num_chunks = -1,
239
+ **kwargs,
240
+ ):
241
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
242
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
243
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
244
+ output_dir = 'unsloth_training_checkpoints'
245
+ save_strategy = 'no'
246
+ if dataset_num_proc is None:
247
+ from multiprocessing import cpu_count
248
+ dataset_num_proc = cpu_count()
249
+
250
+ super().__init__(
251
+ output_dir = output_dir,
252
+ overwrite_output_dir = overwrite_output_dir,
253
+ do_train = do_train,
254
+ do_eval = do_eval,
255
+ do_predict = do_predict,
256
+ eval_strategy = eval_strategy,
257
+ prediction_loss_only = prediction_loss_only,
258
+ per_device_train_batch_size = per_device_train_batch_size,
259
+ per_device_eval_batch_size = per_device_eval_batch_size,
260
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
261
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
262
+ gradient_accumulation_steps = gradient_accumulation_steps,
263
+ eval_accumulation_steps = eval_accumulation_steps,
264
+ eval_delay = eval_delay,
265
+ torch_empty_cache_steps = torch_empty_cache_steps,
266
+ learning_rate = learning_rate,
267
+ weight_decay = weight_decay,
268
+ adam_beta1 = adam_beta1,
269
+ adam_beta2 = adam_beta2,
270
+ adam_epsilon = adam_epsilon,
271
+ max_grad_norm = max_grad_norm,
272
+ num_train_epochs = num_train_epochs,
273
+ max_steps = max_steps,
274
+ lr_scheduler_type = lr_scheduler_type,
275
+ warmup_ratio = warmup_ratio,
276
+ warmup_steps = warmup_steps,
277
+ log_level = log_level,
278
+ log_level_replica = log_level_replica,
279
+ log_on_each_node = log_on_each_node,
280
+ logging_dir = logging_dir,
281
+ logging_strategy = logging_strategy,
282
+ logging_first_step = logging_first_step,
283
+ logging_steps = logging_steps,
284
+ logging_nan_inf_filter = logging_nan_inf_filter,
285
+ save_strategy = save_strategy,
286
+ save_steps = save_steps,
287
+ save_total_limit = save_total_limit,
288
+ save_safetensors = save_safetensors,
289
+ save_on_each_node = save_on_each_node,
290
+ save_only_model = save_only_model,
291
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
292
+ no_cuda = no_cuda,
293
+ use_cpu = use_cpu,
294
+ use_mps_device = use_mps_device,
295
+ seed = seed,
296
+ data_seed = data_seed,
297
+ jit_mode_eval = jit_mode_eval,
298
+ use_ipex = use_ipex,
299
+ bf16 = bf16,
300
+ fp16 = fp16,
301
+ fp16_opt_level = fp16_opt_level,
302
+ half_precision_backend = half_precision_backend,
303
+ bf16_full_eval = bf16_full_eval,
304
+ fp16_full_eval = fp16_full_eval,
305
+ tf32 = tf32,
306
+ local_rank = local_rank,
307
+ ddp_backend = ddp_backend,
308
+ tpu_num_cores = tpu_num_cores,
309
+ tpu_metrics_debug = tpu_metrics_debug,
310
+ debug = debug,
311
+ dataloader_drop_last = dataloader_drop_last,
312
+ eval_steps = eval_steps,
313
+ dataloader_num_workers = dataloader_num_workers,
314
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
315
+ past_index = past_index,
316
+ run_name = run_name,
317
+ disable_tqdm = disable_tqdm,
318
+ remove_unused_columns = remove_unused_columns,
319
+ label_names = label_names,
320
+ load_best_model_at_end = load_best_model_at_end,
321
+ metric_for_best_model = metric_for_best_model,
322
+ greater_is_better = greater_is_better,
323
+ ignore_data_skip = ignore_data_skip,
324
+ fsdp = fsdp,
325
+ fsdp_min_num_params = fsdp_min_num_params,
326
+ fsdp_config = fsdp_config,
327
+ tp_size = tp_size,
328
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
329
+ accelerator_config = accelerator_config,
330
+ deepspeed = deepspeed,
331
+ label_smoothing_factor = label_smoothing_factor,
332
+ optim = optim,
333
+ optim_args = optim_args,
334
+ adafactor = adafactor,
335
+ group_by_length = group_by_length,
336
+ length_column_name = length_column_name,
337
+ report_to = report_to,
338
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
339
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
340
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
341
+ dataloader_pin_memory = dataloader_pin_memory,
342
+ dataloader_persistent_workers = dataloader_persistent_workers,
343
+ skip_memory_metrics = skip_memory_metrics,
344
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
345
+ push_to_hub = push_to_hub,
346
+ resume_from_checkpoint = resume_from_checkpoint,
347
+ hub_model_id = hub_model_id,
348
+ hub_strategy = hub_strategy,
349
+ hub_token = hub_token,
350
+ hub_private_repo = hub_private_repo,
351
+ hub_always_push = hub_always_push,
352
+ gradient_checkpointing = gradient_checkpointing,
353
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
354
+ include_inputs_for_metrics = include_inputs_for_metrics,
355
+ eval_do_concat_batches = eval_do_concat_batches,
356
+ fp16_backend = fp16_backend,
357
+ push_to_hub_model_id = push_to_hub_model_id,
358
+ push_to_hub_organization = push_to_hub_organization,
359
+ push_to_hub_token = push_to_hub_token,
360
+ mp_parameters = mp_parameters,
361
+ auto_find_batch_size = auto_find_batch_size,
362
+ full_determinism = full_determinism,
363
+ torchdynamo = torchdynamo,
364
+ ray_scope = ray_scope,
365
+ ddp_timeout = ddp_timeout,
366
+ torch_compile = torch_compile,
367
+ torch_compile_backend = torch_compile_backend,
368
+ torch_compile_mode = torch_compile_mode,
369
+ include_tokens_per_second = include_tokens_per_second,
370
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
371
+ neftune_noise_alpha = neftune_noise_alpha,
372
+ optim_target_modules = optim_target_modules,
373
+ batch_eval_metrics = batch_eval_metrics,
374
+ eval_on_start = eval_on_start,
375
+ use_liger_kernel = use_liger_kernel,
376
+ eval_use_gather_object = eval_use_gather_object,
377
+ average_tokens_across_devices = average_tokens_across_devices,
378
+ max_length = max_length,
379
+ max_prompt_length = max_prompt_length,
380
+ max_completion_length = max_completion_length,
381
+ beta = beta,
382
+ disable_dropout = disable_dropout,
383
+ label_pad_token_id = label_pad_token_id,
384
+ padding_value = padding_value,
385
+ truncation_mode = truncation_mode,
386
+ generate_during_eval = generate_during_eval,
387
+ is_encoder_decoder = is_encoder_decoder,
388
+ model_init_kwargs = model_init_kwargs,
389
+ dataset_num_proc = dataset_num_proc,**kwargs)
390
+ self.vllm_sampling_params = vllm_sampling_params
391
+ self.unsloth_num_chunks = unsloth_num_chunks
392
+ pass
393
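A minimal configuration sketch for the class above; the field values are illustrative, not recommendations:

config = UnslothORPOConfig(
    output_dir = 'orpo-checkpoints',
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    learning_rate = 5e-6,            # must sit inside the (1e-7, 1) guard rails above
    beta = 0.1,                      # the λ weight from the ORPO paper
    max_length = 1024,
    max_prompt_length = 512,
)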
+
394
+ class _UnslothORPOTrainer(Trainer):
395
+ r""""""
396
+
397
+ _tag_names = ["trl", "orpo"]
398
+
399
+ def __init__(
400
+ self,
401
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
402
+ args: Optional[ORPOConfig] = None,
403
+ data_collator: Optional[DataCollator] = None,
404
+ train_dataset: Optional[Dataset] = None,
405
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
406
+ processing_class: Optional[
407
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
408
+ ] = None,
409
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
410
+ callbacks: Optional[list[TrainerCallback]] = None,
411
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
412
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
413
+ peft_config: Optional[dict] = None,
414
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
415
+ ):
416
+ if args.model_init_kwargs is None:
417
+ model_init_kwargs = {}
418
+ elif not isinstance(model, str):
419
+ raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
420
+ else:
421
+ model_init_kwargs = args.model_init_kwargs
422
+ torch_dtype = model_init_kwargs.get("torch_dtype")
423
+ if torch_dtype is not None:
424
+ # Convert to `torch.dtype` if an str is passed
425
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
426
+ torch_dtype = getattr(torch, torch_dtype)
427
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
428
+ raise ValueError(
429
+ f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
430
+ )
431
+ model_init_kwargs["torch_dtype"] = torch_dtype
432
+
433
+ if isinstance(model, str):
434
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
435
+
436
+ # Initialize this variable to False. This helps track the case when `peft_module_casting_to_bf16`
437
+ # has been called in order to properly call autocast if needed.
438
+ self._peft_has_been_casted_to_bf16 = False
439
+
440
+ if not is_peft_available() and peft_config is not None:
441
+ raise ValueError(
442
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
443
+ )
444
+ elif is_peft_available() and peft_config is not None:
445
+ # if model is a peft model and we have a peft_config, we merge and unload it first
446
+ if isinstance(model, PeftModel):
447
+ model = model.merge_and_unload()
448
+
449
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
450
+ _support_gc_kwargs = hasattr(
451
+ args, "gradient_checkpointing_kwargs"
452
+ ) and "gradient_checkpointing_kwargs" in list(
453
+ inspect.signature(prepare_model_for_kbit_training).parameters
454
+ )
455
+
456
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
457
+
458
+ if _support_gc_kwargs:
459
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
460
+
461
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
462
+ elif getattr(args, "gradient_checkpointing", False):
463
+ # For backward compatibility with older versions of transformers
464
+ if hasattr(model, "enable_input_require_grads"):
465
+ model.enable_input_require_grads()
466
+ else:
467
+
468
+ def make_inputs_require_grad(module, input, output):
469
+ output.requires_grad_(True)
470
+
471
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
472
+
473
+ # get peft model with the given config
474
+ model = model
475
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
476
+ peft_module_casting_to_bf16(model)
477
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
478
+ self._peft_has_been_casted_to_bf16 = True
479
+
480
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
481
+ # to explicitly have `requires_grad=True`; otherwise training will either
482
+ # fail silently or fail outright.
483
+ elif getattr(args, "gradient_checkpointing", False):
484
+ # For backward compatibility with older versions of transformers
485
+ if hasattr(model, "enable_input_require_grads"):
486
+ model.enable_input_require_grads()
487
+ else:
488
+
489
+ def make_inputs_require_grad(module, input, output):
490
+ output.requires_grad_(True)
491
+
492
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
493
+
494
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
495
+ raise ValueError(
496
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
497
+ " Please install `wandb` or `comet-ml` to resolve."
498
+ )
499
+
500
+ if model is not None:
501
+ self.is_encoder_decoder = model.config.is_encoder_decoder
502
+ elif args.is_encoder_decoder is None:
503
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
504
+ else:
505
+ self.is_encoder_decoder = args.is_encoder_decoder
506
+
507
+ if self.is_encoder_decoder:
508
+ self.decoder_start_token_id = model.config.decoder_start_token_id
509
+ self.pad_token_id = model.config.pad_token_id
510
+
511
+ if processing_class is None:
512
+ raise ValueError("processing_class must be specified to tokenize an ORPO dataset.")
513
+ if args.max_length is None:
514
+ warnings.warn(
515
+ "`max_length` is not set in the ORPOConfig's init;"
516
+ " it will default to `512`, but you should set it yourself in the future.",
517
+ UserWarning,
518
+ )
519
+ max_length = 512
520
+ else:
521
+ max_length = args.max_length
522
+ if args.max_prompt_length is None:
523
+ warnings.warn(
524
+ "`max_prompt_length` is not set in the ORPOConfig's init;"
525
+ " it will default to `128`, but you should set it yourself in the future.",
526
+ UserWarning,
527
+ )
528
+ max_prompt_length = 128
529
+ else:
530
+ max_prompt_length = args.max_prompt_length
531
+
532
+ if args.max_completion_length is None and self.is_encoder_decoder:
533
+ warnings.warn(
534
+ "When using an encoder-decoder architecture, you should set `max_completion_length` in the ORPOConfig's init;"
535
+ " it will default to `128`, but you should set it yourself in the future.",
536
+ UserWarning,
537
+ )
538
+ self.max_completion_length = 128
539
+ else:
540
+ self.max_completion_length = args.max_completion_length
541
+
542
+ if data_collator is None:
543
+ data_collator = DPODataCollatorWithPadding(
544
+ pad_token_id=processing_class.pad_token_id,
545
+ label_pad_token_id=args.label_pad_token_id,
546
+ is_encoder_decoder=self.is_encoder_decoder,
547
+ )
548
+
549
+ if args.remove_unused_columns:
550
+ args.remove_unused_columns = False
551
+ # warn users
552
+ warnings.warn(
553
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
554
+ "; we have set it for you, but you should set it yourself in the future.",
555
+ UserWarning,
556
+ )
557
+
558
+ self.use_dpo_data_collator = True
559
+ else:
560
+ self.use_dpo_data_collator = False
561
+
562
+ # Disable dropout in the model and reference model
563
+ if args.disable_dropout:
564
+ disable_dropout_in_model(model)
565
+
566
+ self.max_length = max_length
567
+ self.generate_during_eval = args.generate_during_eval
568
+ self.label_pad_token_id = args.label_pad_token_id
569
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
570
+ self.max_prompt_length = max_prompt_length
571
+ self.truncation_mode = args.truncation_mode
572
+ self.processing_class = processing_class
573
+
574
+ self.beta = args.beta
575
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
576
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
577
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
578
+ warnings.warn(
579
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
580
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
581
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
582
+ "loss.",
583
+ UserWarning,
584
+ )
585
+
586
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
587
+
588
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
589
+ # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
590
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
591
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
592
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
593
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
594
+ # that the warning has already been issued.
595
+ model.warnings_issued["estimate_tokens"] = True
596
+
597
+ # Compute that only on the main process for faster data processing.
598
+ # see: https://github.com/huggingface/trl/pull/1255
599
+ with PartialState().local_main_process_first():
600
+ # Extract the prompt if needed, and apply the chat template if needed
601
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
602
+ train_dataset = train_dataset.map(
603
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
604
+ )
605
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
606
+ if eval_dataset is not None:
607
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
608
+ eval_dataset = eval_dataset.map(
609
+ maybe_apply_chat_template,
610
+ fn_kwargs={"tokenizer": processing_class},
611
+ num_proc=args.dataset_num_proc,
612
+ )
613
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
614
+
615
+ super().__init__(
616
+ model=model,
617
+ args=args,
618
+ data_collator=data_collator,
619
+ train_dataset=train_dataset,
620
+ eval_dataset=eval_dataset,
621
+ processing_class=processing_class,
622
+ model_init=model_init,
623
+ compute_metrics=compute_metrics,
624
+ callbacks=callbacks,
625
+ optimizers=optimizers,
626
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
627
+ )
628
+
629
+ # Add tags for models that have been loaded with the correct transformers version
630
+ if hasattr(self.model, "add_model_tags"):
631
+ self.model.add_model_tags(self._tag_names)
632
+
633
+ if not hasattr(self, "accelerator"):
634
+ raise AttributeError(
635
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
636
+ )
637
+
638
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
639
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
640
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
641
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
642
+
643
+ if model is not None:
644
+ if hasattr(model, "config"):
645
+ hidden_size = (
646
+ max(model.config.hidden_sizes)
647
+ if getattr(model.config, "hidden_sizes", None)
648
+ else getattr(model.config, "hidden_size", None)
649
+ )
650
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
651
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
652
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
653
+ config_kwargs.update(
654
+ {
655
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
656
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
657
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
658
+ }
659
+ )
660
+
661
+ # If ZeRO-3 is used, we shard both the active and reference model.
662
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
663
+ if config_kwargs["zero_optimization"]["stage"] != 3:
664
+ config_kwargs["zero_optimization"]["stage"] = 0
665
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
666
+ model.eval()
667
+ return model
668
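The ZeRO-3 overrides above scale quadratically with the hidden size. For a hypothetical hidden_size of 4096, the buckets come out as:

hidden_size = 4096                                    # illustrative only
reduce_bucket_size = hidden_size * hidden_size        # 16_777_216
stage3_param_persistence = 10 * hidden_size           # 40_960
stage3_prefetch_bucket = 0.9 * hidden_size ** 2       # 15_099_494.4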
+
669
+ def build_tokenized_answer(self, prompt, answer):
670
+ """
671
+ The Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
672
+ It does, however, ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
673
+ Reference:
674
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
675
+ """
676
+
677
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
678
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
679
+
680
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
681
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
682
+
683
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
684
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
685
+
686
+ # Prepare input tokens for token by token comparison
687
+ full_input_ids = np.array(full_tokenized["input_ids"])
688
+
689
+ if len(full_input_ids) != len(full_concat_input_ids):
690
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
691
+
692
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
693
+ # can be merged together when tokenizing prompt+answer. This could result
694
+ # in the last token from the prompt being different when tokenized on its own
695
+ # vs when done as prompt+answer.
696
+ response_token_ids_start_idx = len(prompt_input_ids)
697
+
698
+ # If the tokenized prompt differs from the prompt slice of prompt+answer, the
699
+ # last token has changed due to merging.
700
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
701
+ response_token_ids_start_idx -= 1
702
+
703
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
704
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
705
+
706
+ if len(prompt_input_ids) != len(prompt_attention_mask):
707
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
708
+
709
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
710
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
711
+
712
+ return dict(
713
+ prompt_input_ids=prompt_input_ids,
714
+ prompt_attention_mask=prompt_attention_mask,
715
+ input_ids=answer_input_ids,
716
+ attention_mask=answer_attention_mask,
717
+ )
718
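The boundary adjustment above can be reproduced with made-up token ids; the lists are hypothetical and exist only to show the off-by-one handling when the tokenizer merges across the prompt/answer boundary:

full = [5, 9, 8, 7, 3]             # enc(prompt + answer)
prompt_ids = [5, 9, 2]             # enc(prompt): last token merged differently
start = len(prompt_ids)
if prompt_ids != full[:start]:     # merge detected at the boundary
    start -= 1                     # hand the merged token to the answer side
answer_ids = full[start:]          # [8, 7, 3]
assert prompt_ids[:start] == full[:start]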
+
719
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
720
+ """Tokenize a single row from an ORPO-specific dataset.
721
+
722
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
723
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
724
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
725
+
726
+ We also create the labels for the chosen/rejected responses, which are of length equal to
727
+ the sum of the length of the prompt and the chosen/rejected response, with
728
+ label_pad_token_id for the prompt tokens.
729
+ """
730
+ batch = {}
731
+ prompt = feature["prompt"]
732
+ chosen = feature["chosen"]
733
+ rejected = feature["rejected"]
734
+
735
+ if not self.is_encoder_decoder:
736
+ # Check issues below for more details
737
+ # 1. https://github.com/huggingface/trl/issues/907
738
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
739
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
740
+
741
+ if not isinstance(prompt, str):
742
+ raise ValueError(f"prompt should be a str but got {type(prompt)}")
743
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
744
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
745
+
746
+ if not isinstance(chosen, str):
747
+ raise ValueError(f"chosen should be a str but got {type(chosen)}")
748
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
749
+
750
+ if not isinstance(rejected, str):
751
+ raise ValueError(f"rejected should be a str but got {type(rejected)}")
752
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
753
+
754
+ # Last prompt token might get merged by tokenizer and
755
+ # it should not be included for generation if that happens
756
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
757
+
758
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
759
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
760
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
761
+
762
+ for k, v in prompt_tokens.items():
763
+ prompt_tokens[k] = v[:prompt_len_input_ids]
764
+
765
+ # Make sure the prompts differ by at most one token,
766
+ # and that their lengths differ by at most 1
767
+ num_diff_tokens = sum(
768
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
769
+ )
770
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
771
+ if num_diff_tokens > 1 or num_diff_len > 1:
772
+ raise ValueError(
773
+ "Chosen and rejected prompt_input_ids might only differ on the "
774
+ "last token due to tokenizer merge ops."
775
+ )
776
+
777
+ # add BOS token to head of prompt. Avoid adding if it's already there
778
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
779
+ self.processing_class.bos_token_id,
780
+ prompt_len_input_ids,
781
+ prompt_tokens,
782
+ chosen_prompt_len_input_ids,
783
+ chosen_tokens,
784
+ rejected_prompt_len_input_ids,
785
+ rejected_tokens,
786
+ )
787
+
788
+ # add EOS token to end of answer. Avoid adding if it's already there
789
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
790
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
791
+ )
792
+
793
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
794
+
795
+ # if combined sequence is too long, truncate the prompt
796
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
797
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
798
+ if self.truncation_mode == "keep_start":
799
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
800
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
801
+ elif self.truncation_mode == "keep_end":
802
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
803
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
804
+ else:
805
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
806
+
807
+ # if that's still too long, truncate the response
808
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
809
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
810
+ for k in ["input_ids", "attention_mask"]:
811
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
812
+
813
+ # Create labels
814
+ chosen_sequence_tokens = {
815
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
816
+ }
817
+ rejected_sequence_tokens = {
818
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
819
+ }
820
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
821
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
822
+ self.label_pad_token_id
823
+ ] * len(chosen_tokens["prompt_input_ids"])
824
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
825
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
826
+ self.label_pad_token_id
827
+ ] * len(rejected_tokens["prompt_input_ids"])
828
+
829
+ for k, toks in {
830
+ "chosen_": chosen_sequence_tokens,
831
+ "rejected_": rejected_sequence_tokens,
832
+ "": prompt_tokens,
833
+ }.items():
834
+ for type_key, tokens in toks.items():
835
+ if type_key == "token_type_ids":
836
+ continue
837
+ batch[f"{k}{type_key}"] = tokens
838
+
839
+ else:
840
+ chosen_tokens = self.processing_class(
841
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
842
+ )
843
+ rejected_tokens = self.processing_class(
844
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
845
+ )
846
+ prompt_tokens = self.processing_class(
847
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
848
+ )
849
+
850
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
851
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
852
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
853
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
854
+
855
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
856
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
857
+ labels=torch.tensor(batch["rejected_labels"])
858
+ )
859
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
860
+ labels=torch.tensor(batch["chosen_labels"])
861
+ )
862
+
863
+ if is_torch_xla_available():
864
+ # Pad the sequences to global max_length to avoid TorchXLA recompilation
865
+ for k in batch:
866
+ if "labels" in k or self.is_encoder_decoder:
867
+ pad_value = self.label_pad_token_id
868
+ elif k.endswith("_input_ids"):
869
+ pad_value = self.padding_value
870
+ elif k.endswith("_attention_mask"):
871
+ pad_value = 0
872
+ batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
873
+ return batch
874
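The two truncation passes above interact as follows; a worked example with hypothetical limits:

max_length, max_prompt_length = 12, 8                 # hypothetical limits
prompt_len, response_len = 10, 6
if prompt_len + response_len > max_length:
    prompt_len = max_prompt_length                    # pass 1: cut the prompt
if prompt_len + response_len > max_length:
    response_len = max_length - max_prompt_length     # pass 2: cut the response
assert (prompt_len, response_len) == (8, 4)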
+
875
+ @staticmethod
876
+ def concatenated_inputs(
877
+ batch: dict[str, Union[list, torch.LongTensor]],
878
+ is_encoder_decoder: bool = False,
879
+ label_pad_token_id: int = -100,
880
+ padding_value: int = 0,
881
+ device: Optional[torch.device] = None,
882
+ ) -> dict[str, torch.LongTensor]:
883
+ """Concatenate the chosen and rejected inputs into a single tensor.
884
+
885
+ Args:
886
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
887
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
888
+ label_pad_token_id: The label pad token id.
889
+ padding_value: The padding value to use for the concatenated inputs_ids.
890
+ device: The device for the concatenated inputs.
891
+
892
+ Returns:
893
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
894
+ """
895
+ concatenated_batch = {}
896
+
897
+ if is_encoder_decoder:
898
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
899
+ else:
900
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
901
+
902
+ for k in batch:
903
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
904
+ if "labels" in k or is_encoder_decoder:
905
+ pad_value = label_pad_token_id
906
+ elif k.endswith("_input_ids"):
907
+ pad_value = padding_value
908
+ elif k.endswith("_attention_mask"):
909
+ pad_value = 0
910
+ concatenated_key = k.replace("chosen", "concatenated")
911
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
912
+ for k in batch:
913
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
914
+ if "labels" in k or is_encoder_decoder:
915
+ pad_value = label_pad_token_id
916
+ elif k.endswith("_input_ids"):
917
+ pad_value = padding_value
918
+ elif k.endswith("_attention_mask"):
919
+ pad_value = 0
920
+ concatenated_key = k.replace("rejected", "concatenated")
921
+ concatenated_batch[concatenated_key] = torch.cat(
922
+ (
923
+ concatenated_batch[concatenated_key],
924
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
925
+ ),
926
+ dim=0,
927
+ ).to(device=device)
928
+
929
+ if is_encoder_decoder:
930
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
931
+ concatenated_batch["concatenated_attention_mask"] = (
932
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
933
+ )
934
+
935
+ return concatenated_batch
936
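Shape-wise, the method above right-pads the shorter side and stacks chosen rows on top of rejected rows, doubling the batch dimension. A toy check (F.pad stands in for trl's pad_to_length; values are illustrative):

import torch
import torch.nn.functional as F

chosen = torch.ones(2, 3, dtype = torch.long)         # (batch, len_chosen)
rejected = torch.ones(2, 5, dtype = torch.long)       # (batch, len_rejected)
max_len = max(chosen.shape[1], rejected.shape[1])
padded_chosen = F.pad(chosen, (0, max_len - chosen.shape[1]), value = 0)
concatenated = torch.cat((padded_chosen, rejected), dim = 0)
assert concatenated.shape == (4, 5)                   # chosen rows first, then rejected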
+
937
+ def odds_ratio_loss(
938
+ self,
939
+ policy_chosen_logps: torch.FloatTensor,
940
+ policy_rejected_logps: torch.FloatTensor,
941
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
942
+ """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
943
+
944
+ Args:
945
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
946
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
947
+
948
+ Returns:
949
+ A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, mean log(sigmoid(log_odds)), mean log_odds).
950
+ The losses tensor contains the ORPO loss for each example in the batch.
951
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
952
+ The mean `log(sigmoid(log_odds))` of the chosen over the rejected responses, for logging purposes.
953
+ The mean log odds ratio of the chosen over the rejected responses, for logging purposes.
954
+ """
955
+
956
+ # Derived from Eqs. (4) and (7) of https://huggingface.co/papers/2403.07691, using log identities and exp(log(P(y|x))) = P(y|x)
957
+ log_odds = (policy_chosen_logps - policy_rejected_logps) - (
958
+ torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
959
+ )
960
+ ratio = F.logsigmoid(log_odds)
961
+ losses = self.beta * ratio
962
+
963
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
964
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
965
+
966
+ return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
967
+
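+ # Editor's note on the math above: with p_c = exp(policy_chosen_logps) and
+ # p_r = exp(policy_rejected_logps), the odds are odds(p) = p / (1 - p), so
+ #   log_odds = log(odds(p_c) / odds(p_r))
+ #            = (log p_c - log p_r) - (log(1 - p_c) - log(1 - p_r)),
+ # which is exactly the log1p(-exp(.)) expression used here. `losses` is the
+ # (negative) per-example term beta * log(sigmoid(log_odds));
+ # get_batch_loss_metrics subtracts its mean from the NLL loss, i.e. it
+ # maximizes the log-sigmoid of the odds ratio.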
968
+ @staticmethod
969
+ def get_batch_logps(
970
+ logits: torch.FloatTensor,
971
+ labels: torch.LongTensor,
972
+ average_log_prob: bool = False,
973
+ label_pad_token_id: int = -100,
974
+ is_encoder_decoder: bool = False,
975
+ ) -> torch.FloatTensor:
976
+ """Compute the log probabilities of the given labels under the given logits.
977
+
978
+ Args:
979
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
980
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
981
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
982
+ label_pad_token_id: The label pad token id.
983
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
984
+
985
+ Returns:
986
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
987
+ """
988
+ if logits.shape[:-1] != labels.shape:
989
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
990
+
991
+ if not is_encoder_decoder:
992
+ labels = labels[:, 1:].clone()
993
+ logits = logits[:, :-1, :]
994
+ loss_mask = labels != label_pad_token_id
995
+
996
+ # dummy token; we'll ignore the losses on these tokens later
997
+ labels = torch.where(labels == label_pad_token_id, 0, labels)
998
+
999
+ per_token_logps = selective_log_softmax(logits, labels)
1000
+
1001
+ if average_log_prob:
1002
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1003
+ else:
1004
+ return (per_token_logps * loss_mask).sum(-1)
1005
+
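+ # Editor's note: for decoder-only models the labels are shifted one step
+ # relative to the logits, and with average_log_prob=True the result is the
+ # mean of log p(label_t | prefix) over non-masked tokens. ORPO calls this
+ # with average_log_prob=True so the odds-ratio term is not dominated by
+ # sequence length.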
1006
+ def concatenated_forward(
1007
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1008
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1009
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1010
+
1011
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1012
+ """
1013
+ concatenated_batch = self.concatenated_inputs(
1014
+ batch,
1015
+ is_encoder_decoder=self.is_encoder_decoder,
1016
+ label_pad_token_id=self.label_pad_token_id,
1017
+ padding_value=self.padding_value,
1018
+ device=self.accelerator.device,
1019
+ )
1020
+ len_chosen = batch["chosen_labels"].shape[0]
1021
+
1022
+ model_kwargs = (
1023
+ {
1024
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1025
+ }
1026
+ if self.is_encoder_decoder
1027
+ else {}
1028
+ )
1029
+
1030
+ if self.aux_loss_enabled:
1031
+ model_kwargs["output_router_logits"] = True
1032
+
1033
+ outputs = model(
1034
+ concatenated_batch["concatenated_input_ids"],
1035
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1036
+ use_cache=False,
1037
+ **model_kwargs,
1038
+ )
1039
+ all_logits = outputs.logits
1040
+
1041
+ def cross_entropy_loss(logits, labels):
1042
+ if not self.is_encoder_decoder:
1043
+ # Shift so that tokens < n predict n
1044
+ logits = logits[..., :-1, :].contiguous()
1045
+ labels = labels[..., 1:].contiguous()
1046
+ # Flatten the tokens
1047
+ loss_fct = nn.CrossEntropyLoss()
1048
+ logits = logits.view(-1, logits.shape[-1])
1049
+ labels = labels.view(-1)
1050
+ # Enable model parallelism
1051
+ labels = labels.to(logits.device)
1052
+ loss = loss_fct(logits, labels)
1053
+ return loss
1054
+
1055
+ if self.is_encoder_decoder:
1056
+ labels = concatenated_batch["concatenated_labels"].clone()
1057
+ else:
1058
+ labels = concatenated_batch["concatenated_input_ids"].clone()
1059
+ attention_mask = concatenated_batch["concatenated_attention_mask"]
1060
+ labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
1061
+ # orpo chosen nll loss is computed over the full prompt and response
1062
+ chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1063
+
1064
+ all_logps = self.get_batch_logps(
1065
+ all_logits,
1066
+ concatenated_batch["concatenated_labels"],
1067
+ average_log_prob=True,
1068
+ is_encoder_decoder=self.is_encoder_decoder,
1069
+ label_pad_token_id=self.label_pad_token_id,
1070
+ )
1071
+
1072
+ chosen_logps = all_logps[:len_chosen]
1073
+ rejected_logps = all_logps[len_chosen:]
1074
+
1075
+ if not self.is_encoder_decoder:
1076
+ chosen_logits = all_logits[:len_chosen, :-1, :]
1077
+ rejected_logits = all_logits[len_chosen:, :-1, :]
1078
+ else:
1079
+ chosen_logits = all_logits[:len_chosen]
1080
+ rejected_logits = all_logits[len_chosen:]
1081
+
1082
+ if self.aux_loss_enabled:
1083
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
1084
+
1085
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
1086
+
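+ # Editor's note: the chosen and rejected halves share one forward pass;
+ # logits rows [0, len_chosen) belong to the chosen responses and rows
+ # [len_chosen, 2 * len_chosen) to the rejected ones, so a batch of B
+ # preference pairs costs one forward over 2 * B sequences instead of two
+ # separate passes.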
1087
+ def get_batch_loss_metrics(
1088
+ self,
1089
+ model,
1090
+ batch: dict[str, Union[list, torch.LongTensor]],
1091
+ train_eval: Literal["train", "eval"] = "train",
1092
+ ):
1093
+ """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
1094
+ metrics = {}
1095
+
1096
+ forward_output = self.concatenated_forward(model, batch)
1097
+ (
1098
+ policy_chosen_logps,
1099
+ policy_rejected_logps,
1100
+ policy_chosen_logits,
1101
+ policy_rejected_logits,
1102
+ policy_nll_loss,
1103
+ ) = forward_output[:5]
1104
+ if self.aux_loss_enabled:
1105
+ aux_loss = forward_output[5]
1106
+
1107
+ losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
1108
+ policy_chosen_logps, policy_rejected_logps
1109
+ )
1110
+ # full ORPO loss
1111
+ loss = policy_nll_loss - losses.mean()
1112
+
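+ # Editor's note: equivalently loss = L_SFT + beta * L_OR with
+ # L_OR = -mean(log(sigmoid(log_odds))), the combined SFT + odds-ratio
+ # objective of the ORPO paper; `losses` already carries the beta factor,
+ # so subtracting its mean adds a positive odds-ratio penalty to the NLL term.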
1113
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1114
+
1115
+ prefix = "eval_" if train_eval == "eval" else ""
1116
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
1117
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
1118
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
1119
+ metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
1120
+ chosen_rewards - rejected_rewards
1121
+ ).mean()
1122
+ metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
1123
+ metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
1124
+ metrics[f"{prefix}logits/rejected"] = (
1125
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
1126
+ )
1127
+ metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
1128
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
1129
+ metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
1130
+ metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
1131
+ if is_torch_xla_available():
1132
+ xm.mark_step() # needed because of the .item() calls below
1133
+ for k, v in metrics.items():
1134
+ metrics[k] = v.item()
1135
+ if self.aux_loss_enabled:
1136
+ loss += self.aux_loss_coef * aux_loss
1137
+
1138
+ return loss, metrics
1139
+
1140
+ def compute_loss(
1141
+ self,
1142
+ model: Union[PreTrainedModel, nn.Module],
1143
+ inputs: dict[str, Union[torch.Tensor, Any]],
1144
+ return_outputs=False,
1145
+ num_items_in_batch=None,
1146
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1147
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1148
+
1149
+ with compute_loss_context_manager:
1150
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1151
+
1152
+ # Make sure the loss is on the same device as the accumulated loss inside the `Trainer` class:
1153
+ loss = loss.to(self.args.device)
1154
+
1155
+ # force log the metrics
1156
+ self.store_metrics(metrics, train_eval="train")
1157
+
1158
+ if return_outputs:
1159
+ return (loss, metrics)
1160
+ return loss
1161
+
1162
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1163
+ """Generate samples from the model and reference model for the given batch of inputs."""
1164
+
1165
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1166
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1167
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1168
+
1169
+ with generate_context_manager:
1170
+ policy_output = model.generate(
1171
+ input_ids=batch["prompt_input_ids"],
1172
+ attention_mask=batch["prompt_attention_mask"],
1173
+ max_length=self.max_length,
1174
+ do_sample=True,
1175
+ pad_token_id=self.processing_class.pad_token_id,
1176
+ )
1177
+
1178
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1179
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1180
+
1181
+ return policy_output_decoded
1182
+
1183
+ def prediction_step(
1184
+ self,
1185
+ model: Union[PreTrainedModel, nn.Module],
1186
+ inputs: dict[str, Union[torch.Tensor, Any]],
1187
+ prediction_loss_only: bool,
1188
+ ignore_keys: Optional[list[str]] = None,
1189
+ ):
1190
+ if not self.use_dpo_data_collator:
1191
+ warnings.warn(
1192
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1193
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1194
+ )
1195
+ if ignore_keys is None:
1196
+ if hasattr(model, "config"):
1197
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1198
+ else:
1199
+ ignore_keys = []
1200
+
1201
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1202
+
1203
+ with torch.no_grad(), prediction_context_manager:
1204
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1205
+
1206
+ # force log the metrics
1207
+ self.store_metrics(metrics, train_eval="eval")
1208
+
1209
+ if prediction_loss_only:
1210
+ return (loss.detach(), None, None)
1211
+
1212
+ # logits for the chosen and rejected samples from model
1213
+ logits_dict = {
1214
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1215
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1216
+ }
1217
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1218
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1219
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1220
+
1221
+ return (loss.detach(), logits, labels)
1222
+
1223
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1224
+ for key, value in metrics.items():
1225
+ self._stored_metrics[train_eval][key].append(value)
1226
+
1227
+ def evaluation_loop(
1228
+ self,
1229
+ dataloader: DataLoader,
1230
+ description: str,
1231
+ prediction_loss_only: Optional[bool] = None,
1232
+ ignore_keys: Optional[list[str]] = None,
1233
+ metric_key_prefix: str = "eval",
1234
+ ) -> EvalLoopOutput:
1235
+ """
1236
+ Overriding built-in evaluation loop to store metrics for each batch.
1237
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1238
+
1239
+ Works both with or without labels.
1240
+ """
1241
+
1242
+ # Sample and save to game log if requested (for one batch to save time)
1243
+ if self.generate_during_eval:
1244
+ # Generate random indices within the range of the total number of samples
1245
+ num_samples = len(dataloader.dataset)
1246
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1247
+
1248
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1249
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1250
+ random_batch = self.data_collator(random_batch_dataset)
1251
+ random_batch = self._prepare_inputs(random_batch)
1252
+
1253
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1254
+
1255
+ table = pd.DataFrame(
1256
+ columns=["Prompt", "Policy"],
1257
+ data=[
1258
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1259
+ ],
1260
+ )
1261
+ if "wandb" in self.args.report_to:
1262
+ wandb.log({"game_log": wandb.Table(data=table)})
1263
+
1264
+ if "comet_ml" in self.args.report_to:
1265
+ log_table_to_comet_experiment(
1266
+ name="game_log.csv",
1267
+ table=table,
1268
+ )
1269
+
1270
+ # Base evaluation
1271
+ initial_output = super().evaluation_loop(
1272
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1273
+ )
1274
+
1275
+ return initial_output
1276
+
1277
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1278
+ """
1279
+ Log `logs` on the various objects watching training, including stored metrics.
1280
+
1281
+ Args:
1282
+ logs (`dict[str, float]`):
1283
+ The values to log.
1284
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1285
+ Start time of the training.
1286
+ """
1287
+ # logs either has 'loss' or 'eval_loss'
1288
+ train_eval = "train" if "loss" in logs else "eval"
1289
+ # Add averaged stored metrics to logs
1290
+ for key, metrics in self._stored_metrics[train_eval].items():
1291
+ logs[key] = torch.tensor(metrics).mean().item()
1292
+ del self._stored_metrics[train_eval]
1293
+
1294
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1295
+ return super().log(logs, start_time)
1296
+ else: # transformers<=4.46
1297
+ return super().log(logs)
1298
+
1299
+ def _shift_right(self, input_ids):
1300
+ if self.decoder_start_token_id is None:
1301
+ raise ValueError(
1302
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1303
+ )
1304
+
1305
+ # shift inputs to the right
1306
+ if is_torch_fx_proxy(input_ids):
1307
+ # Item assignment is not supported natively for proxies.
1308
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1309
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1310
+ else:
1311
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1312
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1313
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1314
+
1315
+ if self.pad_token_id is None:
1316
+ raise ValueError("model.config.pad_token_id has to be defined.")
1317
+ # replace possible -100 values in labels by `pad_token_id`
1318
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1319
+
1320
+ return shifted_input_ids
1321
+
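+ # Editor's sketch (hypothetical token ids): _shift_right([[a, b, c]])
+ # returns [[decoder_start_token_id, a, b]], and any -100 label placeholders
+ # are replaced with pad_token_id, mirroring the standard T5-style decoder
+ # input preparation.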
1322
+ def create_model_card(
1323
+ self,
1324
+ model_name: Optional[str] = None,
1325
+ dataset_name: Optional[str] = None,
1326
+ tags: Union[str, list[str], None] = None,
1327
+ ):
1328
+ """
1329
+ Creates a draft of a model card using the information available to the `Trainer`.
1330
+
1331
+ Args:
1332
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1333
+ Name of the model.
1334
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1335
+ Name of the dataset used for training.
1336
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1337
+ Tags to be associated with the model card.
1338
+ """
1339
+ if not self.is_world_process_zero():
1340
+ return
1341
+
1342
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1343
+ base_model = self.model.config._name_or_path
1344
+ else:
1345
+ base_model = None
1346
+
1347
+ tags = tags or []
1348
+ if isinstance(tags, str):
1349
+ tags = [tags]
1350
+
1351
+ if hasattr(self.model.config, "unsloth_version"):
1352
+ tags.append("unsloth")
1353
+
1354
+ citation = textwrap.dedent("""\
1355
+ @article{hong2024orpo,
1356
+ title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
1357
+ author = {Jiwoo Hong and Noah Lee and James Thorne},
1358
+ year = 2024,
1359
+ eprint = {arXiv:2403.07691}
1360
+ }""")
1361
+
1362
+ model_card = generate_model_card(
1363
+ base_model=base_model,
1364
+ model_name=model_name,
1365
+ hub_model_id=self.hub_model_id,
1366
+ dataset_name=dataset_name,
1367
+ tags=tags,
1368
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1369
+ comet_url=get_comet_experiment_url(),
1370
+ trainer_name="ORPO",
1371
+ trainer_citation=citation,
1372
+ paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
1373
+ paper_id="2403.07691",
1374
+ )
1375
+
1376
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1377
+ class UnslothORPOTrainer(_UnslothORPOTrainer):
1378
+ """
1379
+
1380
+ Initialize ORPOTrainer.
1381
+
1382
+ Args:
1383
+ model (`transformers.PreTrainedModel`):
1384
+ The model to train, preferably an `AutoModelForCausalLM`.
1385
+ args (`ORPOConfig`):
1386
+ The ORPO config arguments to use for training.
1387
+ data_collator (`transformers.DataCollator`):
1388
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
1389
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1390
+ train_dataset (`datasets.Dataset`):
1391
+ The dataset to use for training.
1392
+ eval_dataset (`datasets.Dataset`):
1393
+ The dataset to use for evaluation.
1394
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1395
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1396
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1397
+ reuse the fine-tuned model.
1398
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1399
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1400
+ callbacks (`list[transformers.TrainerCallback]`):
1401
+ The callbacks to use for training.
1402
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1403
+ The optimizer and scheduler to use for training.
1404
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1405
+ The function to use to preprocess the logits before computing the metrics.
1406
+ peft_config (`dict`, defaults to `None`):
1407
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1408
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1409
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1410
+ a dictionary mapping metric names (strings) to metric values.
1411
+
1412
+ """
1413
+ def __init__(
1414
+ self,
1415
+ model = None,
1416
+ args = None,
1417
+ data_collator = None,
1418
+ train_dataset = None,
1419
+ eval_dataset = None,
1420
+ processing_class = None,
1421
+ model_init = None,
1422
+ callbacks = None,
1423
+ preprocess_logits_for_metrics = None,
1424
+ peft_config = None,
1425
+ compute_metrics = None,
1426
+ **kwargs
1427
+ ):
1428
+ if args is None: args = UnslothORPOConfig()
1429
+ use_bf16 = getattr(args, 'bf16', False)
1430
+ use_fp16 = getattr(args, 'fp16', False)
1431
+ force_float32 = False
1432
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1433
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1434
+ force_float32 = True
1435
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1436
+ dtype = getattr(model.config, 'torch_dtype', None)
1437
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1438
+ from unsloth_zoo.utils import _get_dtype
1439
+ dtype = _get_dtype(dtype)
1440
+ float16 = dtype == torch.float16
1441
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1442
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1443
+ if force_float32:
1444
+ args.fp16 = False
1445
+ args.bf16 = False
1446
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1447
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1448
+ args.fp16 = float16
1449
+ args.bf16 = not float16
1450
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1451
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1452
+ args.eval_strategy = 'steps'
1453
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1454
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1455
+ if ga_steps is not None and ga_steps > 1:
1456
+ from transformers import __version__ as transformers_version
1457
+ if Version(transformers_version) <= Version('4.45.2'):
1458
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1459
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1460
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1461
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1462
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1463
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1464
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1465
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1466
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1467
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1468
+ if force_float32:
1469
+ args.bf16_full_eval = False
1470
+ args.fp16_full_eval = False
1471
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1472
+ args.bf16_full_eval = True
1473
+ args.fp16_full_eval = False
1474
+ elif not bf16_full_eval and not fp16_full_eval:
1475
+ args.bf16_full_eval = args.bf16
1476
+ args.fp16_full_eval = args.fp16
1477
+ _output_logits = False
1478
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1479
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1480
+ if _output_logits:
1481
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1482
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1483
+ pass
1484
+ else:
1485
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1486
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1487
+ if args_max_seq_length is None and model_max_seq_length is not None:
1488
+ max_seq_length = model.max_seq_length
1489
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1490
+ if model is not None and hasattr(model, 'for_training'):
1491
+ model.for_training()
1492
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1493
+ if 'processing_class' in locals():
1494
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1495
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1496
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1497
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1498
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1499
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1500
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1501
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1502
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1503
+ else:
1504
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1505
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1506
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1507
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1508
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1509
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1510
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1511
+ else:
1512
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1513
+ other_metrics = []
1514
+
1515
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1516
+ PatchRLStatistics('orpo_trainer', other_metrics)
1517
+
1518
+ super().__init__(
1519
+ model = model,
1520
+ args = args,
1521
+ data_collator = data_collator,
1522
+ train_dataset = train_dataset,
1523
+ eval_dataset = eval_dataset,
1524
+ processing_class = processing_class,
1525
+ model_init = model_init,
1526
+ callbacks = callbacks,
1527
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1528
+ peft_config = peft_config,
1529
+ compute_metrics = compute_metrics,**kwargs)
1530
+ if hasattr(self, 'neftune_hook_handle'):
1531
+ self.neftune_hook_handle.remove()
1532
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1533
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1534
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1535
+ pass
1536
+
1537
+ pass
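+ # Minimal usage sketch (editor's addition; the dataset and variable names
+ # are illustrative, not part of this file):
+ #   trainer = UnslothORPOTrainer(
+ #       model = model,
+ #       args = UnslothORPOConfig(output_dir = "outputs"),
+ #       train_dataset = ds,  # rows with "prompt", "chosen", "rejected"
+ #       processing_class = tokenizer,
+ #   )
+ #   trainer.train()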
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py ADDED
@@ -0,0 +1,1263 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
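+ # Editor's note: this computes log_softmax(logits).gather(-1, index) without
+ # materializing the full log-softmax, via log_softmax(x)_i = x_i - logsumexp(x).
+ # E.g. for logits [1.0, 2.0] and index 1: 2.0 - log(e**1 + e**2) ≈ -0.3133,
+ # which equals torch.log_softmax(torch.tensor([1., 2.]), -1)[1].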
+ def vLLMSamplingParams(**kwargs):
43
+ from vllm import SamplingParams
44
+ sampling_params = SamplingParams(**kwargs)
45
+ sampling_params._set_kwargs = kwargs
46
+ return sampling_params
47
+ @dataclass
48
+ class UnslothOnlineDPOConfig(OnlineDPOConfig):
49
+ """
50
+
51
+ Configuration class for the [`OnlineDPOTrainer`].
52
+
53
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
54
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
55
+ command line.
56
+
57
+ Parameters:
58
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
59
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
60
+ [`~transformers.TrainingArguments`].
61
+ reward_model_path (`str` or `None`, *optional*, defaults to `None`):
62
+ Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
63
+ judge (`str` or `None`, *optional*, defaults to `None`):
64
+ Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
65
+ max_new_tokens (`int`, *optional*, defaults to `64`):
66
+ Maximum number of tokens to generate per completion.
67
+ max_length (`int`, *optional*, defaults to `256`):
68
+ Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
69
+ sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
70
+ possible.
71
+ temperature (`float`, *optional*, defaults to `0.9`):
72
+ Temperature for sampling. The higher the temperature, the more random the completions.
73
+ missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
74
+ Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage the model
75
+ to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
76
+ value.
77
+ beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
78
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
79
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
80
+ the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
81
+ selected for each new epoch and the last β is used for the rest of the epochs.
82
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
83
+ Type of loss to use. Possible values are:
84
+
85
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
86
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
87
+
88
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
+ Number of processes to use for processing the dataset.
90
+ disable_dropout (`bool`, *optional*, defaults to `True`):
91
+ Whether to disable dropout in the model and reference model.
92
+ use_vllm (`bool`, *optional*, defaults to `False`):
93
+ Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
94
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
95
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
96
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
97
+ capacity of a single GPU, albeit at the cost of slower generation.
98
+
99
+ """
100
+ vllm_sampling_params: Optional[Any] = field(
101
+ default = None,
102
+ metadata = {'help': 'vLLM SamplingParams'},
103
+ )
104
+ unsloth_num_chunks : Optional[int] = field(
105
+ default = -1,
106
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
107
+ )
108
+ def __init__(
109
+ self,
110
+ output_dir = None,
111
+ overwrite_output_dir = None,
112
+ do_train = False,
113
+ do_eval = False,
114
+ do_predict = False,
115
+ eval_strategy = 'no',
116
+ prediction_loss_only = False,
117
+ per_device_train_batch_size = 4,
118
+ per_device_eval_batch_size = 4,
119
+ per_gpu_train_batch_size = None,
120
+ per_gpu_eval_batch_size = None,
121
+ gradient_accumulation_steps = 2,
122
+ eval_accumulation_steps = 2,
123
+ eval_delay = 0,
124
+ torch_empty_cache_steps = 250,
125
+ learning_rate = 5e-05,
126
+ weight_decay = 0.01,
127
+ adam_beta1 = 0.9,
128
+ adam_beta2 = 0.999,
129
+ adam_epsilon = 1e-08,
130
+ max_grad_norm = 1.0,
131
+ num_train_epochs = 3.0,
132
+ max_steps = -1,
133
+ lr_scheduler_type = 'linear',
134
+ warmup_ratio = 0.1,
135
+ warmup_steps = 0,
136
+ log_level = 'passive',
137
+ log_level_replica = 'warning',
138
+ log_on_each_node = True,
139
+ logging_dir = None,
140
+ logging_strategy = 'steps',
141
+ logging_first_step = False,
142
+ logging_steps = 1,
143
+ logging_nan_inf_filter = False,
144
+ save_strategy = 'steps',
145
+ save_steps = 500,
146
+ save_total_limit = None,
147
+ save_safetensors = True,
148
+ save_on_each_node = False,
149
+ save_only_model = False,
150
+ restore_callback_states_from_checkpoint = False,
151
+ no_cuda = False,
152
+ use_cpu = False,
153
+ use_mps_device = False,
154
+ seed = 3407,
155
+ data_seed = 3407,
156
+ jit_mode_eval = False,
157
+ use_ipex = False,
158
+ bf16 = False,
159
+ fp16 = False,
160
+ fp16_opt_level = 'O1',
161
+ half_precision_backend = 'auto',
162
+ bf16_full_eval = False,
163
+ fp16_full_eval = False,
164
+ tf32 = None,
165
+ local_rank = -1,
166
+ ddp_backend = None,
167
+ tpu_num_cores = None,
168
+ tpu_metrics_debug = False,
169
+ debug = '',
170
+ dataloader_drop_last = False,
171
+ eval_steps = None,
172
+ dataloader_num_workers = 0,
173
+ dataloader_prefetch_factor = None,
174
+ past_index = -1,
175
+ run_name = None,
176
+ disable_tqdm = None,
177
+ remove_unused_columns = True,
178
+ label_names = None,
179
+ load_best_model_at_end = False,
180
+ metric_for_best_model = None,
181
+ greater_is_better = None,
182
+ ignore_data_skip = False,
183
+ fsdp = '',
184
+ fsdp_min_num_params = 0,
185
+ fsdp_config = None,
186
+ tp_size = 0,
187
+ fsdp_transformer_layer_cls_to_wrap = None,
188
+ accelerator_config = None,
189
+ deepspeed = None,
190
+ label_smoothing_factor = 0.0,
191
+ optim = 'adamw_8bit',
192
+ optim_args = None,
193
+ adafactor = False,
194
+ group_by_length = False,
195
+ length_column_name = 'length',
196
+ report_to = None,
197
+ ddp_find_unused_parameters = None,
198
+ ddp_bucket_cap_mb = None,
199
+ ddp_broadcast_buffers = None,
200
+ dataloader_pin_memory = True,
201
+ dataloader_persistent_workers = False,
202
+ skip_memory_metrics = True,
203
+ use_legacy_prediction_loop = False,
204
+ push_to_hub = False,
205
+ resume_from_checkpoint = None,
206
+ hub_model_id = None,
207
+ hub_strategy = 'every_save',
208
+ hub_token = None,
209
+ hub_private_repo = None,
210
+ hub_always_push = False,
211
+ gradient_checkpointing = False,
212
+ gradient_checkpointing_kwargs = None,
213
+ include_inputs_for_metrics = False,
214
+ eval_do_concat_batches = True,
215
+ fp16_backend = 'auto',
216
+ push_to_hub_model_id = None,
217
+ push_to_hub_organization = None,
218
+ push_to_hub_token = None,
219
+ mp_parameters = '',
220
+ auto_find_batch_size = False,
221
+ full_determinism = False,
222
+ torchdynamo = None,
223
+ ray_scope = 'last',
224
+ ddp_timeout = 1800,
225
+ torch_compile = False,
226
+ torch_compile_backend = None,
227
+ torch_compile_mode = None,
228
+ include_tokens_per_second = False,
229
+ include_num_input_tokens_seen = False,
230
+ neftune_noise_alpha = None,
231
+ optim_target_modules = None,
232
+ batch_eval_metrics = False,
233
+ eval_on_start = False,
234
+ use_liger_kernel = False,
235
+ eval_use_gather_object = False,
236
+ average_tokens_across_devices = False,
237
+ reward_model_path = None,
238
+ judge = None,
239
+ max_new_tokens = 64,
240
+ max_length = 512,
241
+ temperature = 0.9,
242
+ missing_eos_penalty = None,
243
+ loss_type = 'sigmoid',
244
+ dataset_num_proc = None,
245
+ disable_dropout = True,
246
+ use_vllm = False,
247
+ ds3_gather_for_generation = True,
248
+ vllm_sampling_params = None,
249
+ unsloth_num_chunks = -1,
250
+ **kwargs,
251
+ ):
252
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (< 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
253
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
254
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
255
+ output_dir = 'unsloth_training_checkpoints'
256
+ save_strategy = 'no'
257
+ if dataset_num_proc is None:
258
+ from multiprocessing import cpu_count
259
+ dataset_num_proc = cpu_count()
260
+
261
+ super().__init__(
262
+ output_dir = output_dir,
263
+ overwrite_output_dir = overwrite_output_dir,
264
+ do_train = do_train,
265
+ do_eval = do_eval,
266
+ do_predict = do_predict,
267
+ eval_strategy = eval_strategy,
268
+ prediction_loss_only = prediction_loss_only,
269
+ per_device_train_batch_size = per_device_train_batch_size,
270
+ per_device_eval_batch_size = per_device_eval_batch_size,
271
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
272
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
273
+ gradient_accumulation_steps = gradient_accumulation_steps,
274
+ eval_accumulation_steps = eval_accumulation_steps,
275
+ eval_delay = eval_delay,
276
+ torch_empty_cache_steps = torch_empty_cache_steps,
277
+ learning_rate = learning_rate,
278
+ weight_decay = weight_decay,
279
+ adam_beta1 = adam_beta1,
280
+ adam_beta2 = adam_beta2,
281
+ adam_epsilon = adam_epsilon,
282
+ max_grad_norm = max_grad_norm,
283
+ num_train_epochs = num_train_epochs,
284
+ max_steps = max_steps,
285
+ lr_scheduler_type = lr_scheduler_type,
286
+ warmup_ratio = warmup_ratio,
287
+ warmup_steps = warmup_steps,
288
+ log_level = log_level,
289
+ log_level_replica = log_level_replica,
290
+ log_on_each_node = log_on_each_node,
291
+ logging_dir = logging_dir,
292
+ logging_strategy = logging_strategy,
293
+ logging_first_step = logging_first_step,
294
+ logging_steps = logging_steps,
295
+ logging_nan_inf_filter = logging_nan_inf_filter,
296
+ save_strategy = save_strategy,
297
+ save_steps = save_steps,
298
+ save_total_limit = save_total_limit,
299
+ save_safetensors = save_safetensors,
300
+ save_on_each_node = save_on_each_node,
301
+ save_only_model = save_only_model,
302
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
303
+ no_cuda = no_cuda,
304
+ use_cpu = use_cpu,
305
+ use_mps_device = use_mps_device,
306
+ seed = seed,
307
+ data_seed = data_seed,
308
+ jit_mode_eval = jit_mode_eval,
309
+ use_ipex = use_ipex,
310
+ bf16 = bf16,
311
+ fp16 = fp16,
312
+ fp16_opt_level = fp16_opt_level,
313
+ half_precision_backend = half_precision_backend,
314
+ bf16_full_eval = bf16_full_eval,
315
+ fp16_full_eval = fp16_full_eval,
316
+ tf32 = tf32,
317
+ local_rank = local_rank,
318
+ ddp_backend = ddp_backend,
319
+ tpu_num_cores = tpu_num_cores,
320
+ tpu_metrics_debug = tpu_metrics_debug,
321
+ debug = debug,
322
+ dataloader_drop_last = dataloader_drop_last,
323
+ eval_steps = eval_steps,
324
+ dataloader_num_workers = dataloader_num_workers,
325
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
326
+ past_index = past_index,
327
+ run_name = run_name,
328
+ disable_tqdm = disable_tqdm,
329
+ remove_unused_columns = remove_unused_columns,
330
+ label_names = label_names,
331
+ load_best_model_at_end = load_best_model_at_end,
332
+ metric_for_best_model = metric_for_best_model,
333
+ greater_is_better = greater_is_better,
334
+ ignore_data_skip = ignore_data_skip,
335
+ fsdp = fsdp,
336
+ fsdp_min_num_params = fsdp_min_num_params,
337
+ fsdp_config = fsdp_config,
338
+ tp_size = tp_size,
339
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
340
+ accelerator_config = accelerator_config,
341
+ deepspeed = deepspeed,
342
+ label_smoothing_factor = label_smoothing_factor,
343
+ optim = optim,
344
+ optim_args = optim_args,
345
+ adafactor = adafactor,
346
+ group_by_length = group_by_length,
347
+ length_column_name = length_column_name,
348
+ report_to = report_to,
349
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
350
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
351
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
352
+ dataloader_pin_memory = dataloader_pin_memory,
353
+ dataloader_persistent_workers = dataloader_persistent_workers,
354
+ skip_memory_metrics = skip_memory_metrics,
355
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
356
+ push_to_hub = push_to_hub,
357
+ resume_from_checkpoint = resume_from_checkpoint,
358
+ hub_model_id = hub_model_id,
359
+ hub_strategy = hub_strategy,
360
+ hub_token = hub_token,
361
+ hub_private_repo = hub_private_repo,
362
+ hub_always_push = hub_always_push,
363
+ gradient_checkpointing = gradient_checkpointing,
364
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
365
+ include_inputs_for_metrics = include_inputs_for_metrics,
366
+ eval_do_concat_batches = eval_do_concat_batches,
367
+ fp16_backend = fp16_backend,
368
+ push_to_hub_model_id = push_to_hub_model_id,
369
+ push_to_hub_organization = push_to_hub_organization,
370
+ push_to_hub_token = push_to_hub_token,
371
+ mp_parameters = mp_parameters,
372
+ auto_find_batch_size = auto_find_batch_size,
373
+ full_determinism = full_determinism,
374
+ torchdynamo = torchdynamo,
375
+ ray_scope = ray_scope,
376
+ ddp_timeout = ddp_timeout,
377
+ torch_compile = torch_compile,
378
+ torch_compile_backend = torch_compile_backend,
379
+ torch_compile_mode = torch_compile_mode,
380
+ include_tokens_per_second = include_tokens_per_second,
381
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
382
+ neftune_noise_alpha = neftune_noise_alpha,
383
+ optim_target_modules = optim_target_modules,
384
+ batch_eval_metrics = batch_eval_metrics,
385
+ eval_on_start = eval_on_start,
386
+ use_liger_kernel = use_liger_kernel,
387
+ eval_use_gather_object = eval_use_gather_object,
388
+ average_tokens_across_devices = average_tokens_across_devices,
389
+ reward_model_path = reward_model_path,
390
+ judge = judge,
391
+ max_new_tokens = max_new_tokens,
392
+ max_length = max_length,
393
+ temperature = temperature,
394
+ missing_eos_penalty = missing_eos_penalty,
395
+ loss_type = loss_type,
396
+ dataset_num_proc = dataset_num_proc,
397
+ disable_dropout = disable_dropout,
398
+ use_vllm = use_vllm,
399
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
400
+ self.vllm_sampling_params = vllm_sampling_params
401
+ self.unsloth_num_chunks = unsloth_num_chunks
402
+ pass
403
+
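+ # Minimal usage sketch (editor's addition; values are illustrative):
+ #   args = UnslothOnlineDPOConfig(max_new_tokens = 64, temperature = 0.9,
+ #                                 beta = 0.1, use_vllm = False)
+ # The Unsloth-specific extras on top of trl's OnlineDPOConfig are
+ # `vllm_sampling_params` (built via vLLMSamplingParams(...)) and
+ # `unsloth_num_chunks`.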
404
+ class _UnslothOnlineDPOTrainer(Trainer):
405
+ r""""""
406
+
407
+ _tag_names = ["trl", "online-dpo"]
408
+
409
+ def __init__(
410
+ self,
411
+ model: Union[PreTrainedModel, nn.Module],
412
+ ref_model: Union[PreTrainedModel, nn.Module, None] = None,
413
+ reward_model: Union[PreTrainedModel, nn.Module, None] = None,
414
+ judge: Optional[BasePairwiseJudge] = None,
415
+ args: Optional[OnlineDPOConfig] = None,
416
+ data_collator: Optional[DataCollator] = None,
417
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
418
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
419
+ processing_class: Optional[
420
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
421
+ ] = None,
422
+ reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
423
+ peft_config: Optional[dict] = None,
424
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
425
+ callbacks: Optional[list[TrainerCallback]] = None,
426
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
427
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
428
+ ) -> None:
429
+
430
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and not getattr(args, 'use_vllm', False): args.use_vllm = True
431
+ if ref_model is model:
432
+ raise ValueError(
433
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
434
+ "same as `model`, either omit the `ref_model` argument or pass `None`."
435
+ )
436
+
437
+ self.ref_model = ref_model
438
+
439
+ if reward_model is not None and judge is not None:
440
+ warnings.warn(
441
+ "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
442
+ "Ignoring `judge` and using `reward_model`.",
443
+ UserWarning,
444
+ )
445
+ judge = None
446
+ elif reward_model is None and judge is None:
447
+ raise ValueError("Either `reward_model` or `judge` must be provided.")
448
+
449
+ self.reward_model = reward_model
450
+ self.reward_processing_class = reward_processing_class
451
+ self.judge = judge
452
+
453
+ if args.missing_eos_penalty is not None and judge is not None:
454
+ raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
455
+
456
+ if args is None:
457
+ raise ValueError("`args` must be provided.")
458
+
459
+ # Check that the processing_class is provided
460
+ if processing_class is None:
461
+ raise ValueError("`processing_class` must be provided.")
462
+
463
+ # Convert to PEFT model if peft_config is provided
464
+ if False:
465
+ # Check if PEFT is available
466
+ if not is_peft_available():
467
+ raise ImportError(
468
+ "PEFT is not available and passed `peft_config`. Please install PEFT with "
469
+ "`pip install peft` to use it."
470
+ )
471
+
472
+ # If the model is already a PeftModel, we need to merge and unload it.
473
+ # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
474
+ if isinstance(model, PeftModel):
475
+ model = model.merge_and_unload()
476
+
477
+ # Get peft model with the given config
478
+ model = model
479
+
480
+ # Disable dropout in the model and reference model
481
+ if args.disable_dropout:
482
+ disable_dropout_in_model(model)
483
+ if self.ref_model is not None:
484
+ disable_dropout_in_model(self.ref_model)
485
+
486
+ # Handle the ref_model
487
+ # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
488
+ # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
489
+ # the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
490
+ if ref_model is None: # No ref model provided, the most common case
491
+ if False:
492
+ self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
493
+ else:
494
+ self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
495
+ else: # rare case, the user provided a ref model
496
+ self.ref_model = ref_model
497
+ self.ref_model.eval()
498
+
499
+ # Disable the gradient and set the reward model in eval mode
500
+ if self.reward_model is not None:
501
+ self.reward_model.eval()
502
+
503
+ # Define the collator if it is not provided
504
+ if data_collator is None:
505
+ data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
506
+
507
+ self.max_length = args.max_length
508
+
509
+ self.stats = {
510
+ "objective/kl": [],
511
+ "objective/entropy": [],
512
+ "objective/non_score_reward": [],
513
+ "rewards/chosen": [],
514
+ "rewards/rejected": [],
515
+ "rewards/accuracies": [],
516
+ "rewards/margins": [],
517
+ "logps/chosen": [],
518
+ "logps/rejected": [],
519
+ "val/contain_eos_token": [],
520
+ "beta": [],
521
+ }
522
+ if self.reward_model is not None:
523
+ self.stats["objective/rlhf_reward"] = []
524
+ self.stats["objective/scores_margin"] = []
525
+ self.stats["objective/scores"] = []
526
+
527
+ if args.use_vllm:
528
+ self.llm = model.vllm_engine
+ self._last_loaded_step = 0
+ self.generation_config = SamplingParams(
529
+ n=2, max_tokens=args.max_new_tokens,
530
+ temperature=args.temperature,
531
+ top_k=50,
532
+ top_p=1.0,
533
+ detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
534
+ else:
535
+ self.generation_config = GenerationConfig(
536
+ max_new_tokens=args.max_new_tokens,
537
+ temperature=args.temperature,
538
+ top_k=50,
539
+ top_p=1.0,
540
+ do_sample=True,
541
+ use_cache=False if args.gradient_checkpointing else True,
542
+ )
543
+
544
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
545
+ # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
546
+ # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
547
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
548
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
549
+ # that the warning has already been issued.
550
+ model.warnings_issued["estimate_tokens"] = True
551
+
552
+ super().__init__(
553
+ model=model,
554
+ args=args,
555
+ data_collator=data_collator,
556
+ train_dataset=train_dataset,
557
+ eval_dataset=eval_dataset,
558
+ processing_class=processing_class,
559
+ compute_metrics=compute_metrics,
560
+ callbacks=callbacks,
561
+ optimizers=optimizers,
562
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
563
+ )
564
+
565
+ # Add tags for models that have been loaded with the correct transformers version
566
+ if hasattr(self.model, "add_model_tags"):
567
+ self.model.add_model_tags(self._tag_names)
568
+
569
+ self._beta = args.beta
570
+
571
+ # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
572
+ if self.is_deepspeed_enabled:
573
+ if self.reward_model is not None:
574
+ self.reward_model = prepare_deepspeed(
575
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
576
+ )
577
+ if self.ref_model is not None:
578
+ self.ref_model = prepare_deepspeed(
579
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
580
+ )
581
+ else:
582
+ if self.ref_model is not None:
583
+ self.ref_model = self.ref_model.to(self.accelerator.device)
584
+ if self.reward_model is not None:
585
+ self.reward_model = self.reward_model.to(self.accelerator.device)
586
+
587
+ @property
588
+ def beta(self):
589
+ if isinstance(self._beta, list):
590
+ epoch = self.state.epoch
591
+ return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
592
+ else:
593
+ return self._beta
594
+
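+ # Editor's note: with args.beta = [0.1, 0.2], epoch 0 uses 0.1, epoch 1 uses
+ # 0.2, and later epochs reuse the last entry. Since self.state.epoch is a
+ # float in the HF Trainer, indexing the list assumes an integral value; a
+ # fractional epoch would need an int(...) cast to avoid a TypeError.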
595
+ @staticmethod
596
+ def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
597
+ """Tokenize a single row from a DPO specific dataset."""
598
+ if not is_encoder_decoder:
599
+ batch = tokenizer(feature["prompt"], add_special_tokens=False)
600
+ # Add BOS token to head of prompt. Avoid adding if it's already there
601
+ if tokenizer.bos_token_id is not None:
602
+ prompt_len_input_ids = len(batch["input_ids"])
603
+ if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
604
+ batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
605
+ batch["attention_mask"] = [1] + batch["attention_mask"]
606
+ else:
607
+ batch = tokenizer(feature["prompt"], add_special_tokens=True)
608
+ batch = {f"prompt_{key}": value for key, value in batch.items()}
609
+ return batch
610
+
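+ # Editor's sketch (hypothetical tokenizer `tok`):
+ #   tokenize_row({"prompt": "Hi"}, is_encoder_decoder=False, tokenizer=tok)
+ # returns {"prompt_input_ids": [...], "prompt_attention_mask": [...]},
+ # prepending tok.bos_token_id when the tokenizer defines one and it is not
+ # already the first token.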
611
+ # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
+ @wraps(Trainer.get_train_dataloader)
+ def get_train_dataloader(self) -> DataLoader:
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ dataloader_params = {
+ "batch_size": self._train_batch_size,
+ "collate_fn": data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "persistent_workers": self.args.dataloader_persistent_workers,
+ }
+
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+ dataloader_params["sampler"] = self._get_train_sampler()
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
+ dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+ # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
+ @wraps(Trainer.get_eval_dataloader)
+ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+ # If we have persistent workers, don't do a fork bomb, especially as eval datasets
+ # don't change during training
+ dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+ if (
+ hasattr(self, "_eval_dataloaders")
+ and dataloader_key in self._eval_dataloaders
+ and self.args.dataloader_persistent_workers
+ ):
+ return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
+
+ eval_dataset = (
+ self.eval_dataset[eval_dataset]
+ if isinstance(eval_dataset, str)
+ else eval_dataset
+ if eval_dataset is not None
+ else self.eval_dataset
+ )
+ data_collator = self.data_collator
+
+ dataloader_params = {
+ "batch_size": self.args.eval_batch_size,
+ "collate_fn": data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "persistent_workers": self.args.dataloader_persistent_workers,
+ }
+
+ if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+ dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
+ # accelerator.free_memory() will destroy the references, so
+ # we need to store the non-prepared version
+ eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
+ if self.args.dataloader_persistent_workers:
+ if hasattr(self, "_eval_dataloaders"):
+ self._eval_dataloaders[dataloader_key] = eval_dataloader
+ else:
+ self._eval_dataloaders = {dataloader_key: eval_dataloader}
+
+ return self.accelerator.prepare(eval_dataloader)
+
+ def _generate_vllm(self, model, prompts):
+ eos_token_id = self.processing_class.eos_token_id
+ pad_token_id = self.processing_class.pad_token_id
+
+ # Load the latest weights (a no-op in this compiled cache: the generation calls below
+ # instead pass a `lora_request` that reloads the saved 'online_dpo_trainer_lora_model' adapter)
+ pass
+
+ if is_conversational({"prompt": prompts[0]}):
+ outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
+ else:
+ outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
+
+ completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
+ prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
+
+ # Create mask and pad the prompt and completion
+ max_prompt_length = max(len(ids) for ids in prompt_ids)
+ prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
+ prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
+ max_tokens = self.generation_config.max_tokens
+ completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
+ completion_ids = [
+ ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
+ for ids in completion_ids
+ ]
+ completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
+
+ # Convert to tensors
+ prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
+ prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
+ completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
+ completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
+
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
+
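+ # Illustrative note: the two list comprehensions above interleave samples so that the first
+ # `len(prompts)` rows hold each prompt's first completion and the last `len(prompts)` rows hold
+ # its second completion, matching how `training_step` later splits scores into two halves.
+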
+ def _generate(self, model, prompts):
+ eos_token_id = self.processing_class.eos_token_id
+ pad_token_id = self.processing_class.pad_token_id
+
+ # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
+ # policies with different tokenizers / chat templates.
+ inputs = [{"prompt": prompt} for prompt in prompts]
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
+ inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
+ inputs = self.data_collator(inputs)
+
+ # Sample 2 completions per prompt of size `max_new_tokens` from the model
+ inputs = self._prepare_inputs(inputs)
+ prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
+ prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
+ with unwrap_model_for_generation(
+ model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model:
+ output = unwrapped_model.generate(
+ input_ids=prompt_ids,
+ attention_mask=prompt_mask,
+ generation_config=self.generation_config,
+ )
+
+ completion_ids = output[:, prompt_ids.size(1) :]
+ completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
+
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
+
+ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
+ # Get the number of tokens to truncate from prompt
+ num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
+
+ # Truncate left to avoid OOM
+ prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
+ prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
+
+ # Concat the prompt and completion
+ prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
+ prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
+
+ # Get the logprobs of the completions from the model
+ output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
+
+ # There is an offset of 1 because the model predicts the next token
+ logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
+
+ # Take the completion tokens' logprobs
+ logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
+ return logprobs
+
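+ # Illustrative example of the off-by-one slice above: with a prompt of length P and a completion
+ # of length C, the logits at positions [P-1, ..., P+C-2] are the model's predictions for the C
+ # completion tokens, so slicing `[:, P-1:-1]` aligns each logit row with the token it predicts.
+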
+ def training_step(
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
+ model.train()
+
+ prompts = inputs["prompt"]
+ batch_size = len(prompts)
+
+ if self.args.use_vllm:
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
+ else:
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
+
+ contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
+
+ logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
+ with torch.no_grad():
+ if self.ref_model is not None:
+ ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
+ else: # peft case: we just need to disable the adapter
+ with self.model.disable_adapter():
+ ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
+
+ # Decode the completions, and format them if the input is conversational
+ device = logprobs.device
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+ if is_conversational({"prompt": prompts[0]}):
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
+
+ # Get the reward from the reward model or judge
+ if self.judge is not None:
+ # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
+ # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
+ # independent of the model's chat template, we use the raw conversation data, and apply our own chat
+ # template to it.
+ if is_conversational({"prompt": prompts[0]}):
+ environment = jinja2.Environment()
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
+ prompts = [template.render(messages=prompt) for prompt in prompts]
+ completions = [template.render(messages=completion) for completion in completions]
+
+ ranks_of_first_completion = self.judge.judge(
+ prompts, list(zip(completions[:batch_size], completions[batch_size:]))
+ )
+
+ # convert ranks to a True/False mask:
+ # when rank == 0, it means the first completion is the best
+ # when rank == 1, it means the second completion is the best
+ mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
+ else:
+ # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
+ # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
+ prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
+ if is_conversational({"prompt": prompts[0]}):
+ examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
+ examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
+ prompts = [example["prompt"] for example in examples]
+ completions = [example["completion"] for example in examples]
+
+ # Tokenize the prompts
+ prompts_ids = self.reward_processing_class(
+ prompts, padding=True, return_tensors="pt", padding_side="left"
+ )["input_ids"].to(device)
+ context_length = prompts_ids.shape[1]
+
+ # Tokenize the completions
+ completions_ids = self.reward_processing_class(
+ completions, padding=True, return_tensors="pt", padding_side="right"
+ )["input_ids"].to(device)
+
+ # Concatenate the prompts and completions and get the reward
+ prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
+ with torch.inference_mode():
+ _, scores, _ = get_reward(
+ self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
+ )
+
+ # Filter completions: ensure each sample contains the stop_token_id;
+ # completions not passing that filter will receive a lower score.
+ if self.args.missing_eos_penalty is not None:
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
+
+ # Split the scores in 2 (the prompts of the first half are the same as the second half)
+ first_half, second_half = scores.split(batch_size)
+
+ # Get the indices of the chosen and rejected examples
+ mask = first_half >= second_half
+
+ batch_range = torch.arange(batch_size, device=device)
+ chosen_indices = batch_range + (~mask * batch_size)
+ rejected_indices = batch_range + (mask * batch_size)
+
+ # Build tensor so that the first half is the chosen examples and the second half the rejected examples
+ cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
+ cr_logprobs = logprobs[cr_indices]
+ cr_ref_logprobs = ref_logprobs[cr_indices]
+
+ # mask out the padding tokens
+ padding_mask = ~completion_mask.bool()
+ cr_padding_mask = padding_mask[cr_indices]
+
+ cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
+ cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
+
+ # Split the chosen and rejected examples
+ chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
+ chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
+ pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
+ ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
+
+ logits = pi_logratios - ref_logratios
+
+ if self.args.loss_type == "sigmoid":
+ losses = -F.logsigmoid(self.beta * logits)
+ elif self.args.loss_type == "ipo":
+ losses = (logits - 1 / (2 * self.beta)) ** 2
+ else:
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
+
+ loss = losses.mean()
+
+ # Log everything
+ if self.reward_model is not None:
+ scores_margin = scores[chosen_indices] - scores[rejected_indices]
+ self.stats["objective/scores_margin"].append(
+ self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
+ )
+ self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
+ self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
+ self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
+ self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
+
+ kl = logprobs - ref_logprobs
+ mean_kl = kl.sum(1).mean()
+ self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+ non_score_reward = (-self.beta * kl).sum(1)
+ mean_non_score_reward = non_score_reward.mean()
+ self.stats["objective/non_score_reward"].append(
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+ )
+ if self.reward_model is not None:
+ rlhf_reward = scores + non_score_reward
+ self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
+ mean_entropy = -logprobs.sum(1).mean()
+ self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
+ chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
+ gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
+ self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
+ rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
+ gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
+ self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
+ margin = gathered_chosen_rewards - gathered_rejected_rewards
+ self.stats["rewards/margins"].append(margin.mean().item())
+ accuracy = margin > 0
+ self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
+ self.stats["beta"].append(self.beta)
+
+ if (
+ self.args.torch_empty_cache_steps is not None
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
+ ):
+ empty_cache()
+
+ kwargs = {}
+
+ # For LOMO optimizers you need to explicitly use the learning rate
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ kwargs["learning_rate"] = self._get_learning_rate()
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
+
+ if self.use_apex:
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ self.accelerator.backward(loss, **kwargs)
+
+ return loss.detach() / self.args.gradient_accumulation_steps
+
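+ # Illustrative note on the two loss branches above, using the per-sequence log-probability sums:
+ # sigmoid: -log(sigmoid(beta * ((logp_chosen - logp_rejected) - (ref_logp_chosen - ref_logp_rejected))))
+ # ipo: ((logp_chosen - logp_rejected) - (ref_logp_chosen - ref_logp_rejected) - 1/(2*beta)) ** 2
+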
+ # Same as Trainer._maybe_log_save_evaluate but log our metrics
+ # start_time defaults to None to allow compatibility with transformers<=4.46
+ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+ logs: dict[str, float] = {}
+
+ # all_gather + mean() to get average loss over all processes
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+ # reset tr_loss to zero
+ tr_loss -= tr_loss
+
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ if grad_norm is not None:
+ logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+ logs["learning_rate"] = self._get_learning_rate()
+
+ # Add our metrics
+ for key, val in self.stats.items():
+ logs[key] = sum(val) / len(val)
+ self.stats = {key: [] for key in self.stats} # reset stats
+
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ self.log(logs, start_time)
+ else: # transformers<=4.46
+ self.log(logs)
+
+ metrics = None
+ if self.control.should_evaluate:
+ metrics = self._evaluate(trial, ignore_keys_for_eval)
+ is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+ if self.args.save_strategy == "best":
+ self.control.should_save = is_new_best_metric
+
+ if self.control.should_save:
+ self._save_checkpoint(model, trial)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
+ # This can be removed once the minimum transformers version is updated to 4.47.
+ # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
+ def _determine_best_metric(self, metrics, trial):
+ """
+ Determine if the model should be saved based on the evaluation metrics.
+ If args.metric_for_best_model is not set, the loss is used.
+ Returns:
+ bool: True if a new best metric was found, else False
+ """
+ is_new_best_metric = False
+
+ if self.args.metric_for_best_model is not None:
+ metric_to_check = self.args.metric_for_best_model
+
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+
+ try:
+ metric_value = metrics[metric_to_check]
+ except KeyError as exc:
+ raise KeyError(
+ f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+ f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+ ) from exc
+
+ operator = np.greater if self.args.greater_is_better else np.less
+
+ if self.state.best_metric is None:
+ self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+ if operator(metric_value, self.state.best_metric):
+ run_dir = self._get_output_dir(trial=trial)
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+ self.state.best_metric = metric_value
+ self.state.best_model_checkpoint = output_dir
+
+ is_new_best_metric = True
+
+ return is_new_best_metric
+
+ def create_model_card(
+ self,
+ model_name: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ tags: Union[str, list[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ model_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the model.
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the dataset used for training.
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+ Tags to be associated with the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+ base_model = self.model.config._name_or_path
+ else:
+ base_model = None
+
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
+ citation = textwrap.dedent("""\
+ @article{guo2024direct,
+ title = {{Direct Language Model Alignment from Online AI Feedback}},
+ author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
+ year = 2024,
+ eprint = {arXiv:2402.04792}
+ }""")
+
+ model_card = generate_model_card(
+ base_model=base_model,
+ model_name=model_name,
+ hub_model_id=self.hub_model_id,
+ dataset_name=dataset_name,
+ tags=tags,
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
+ trainer_name="Online DPO",
+ trainer_citation=citation,
+ paper_title="Direct Language Model Alignment from Online AI Feedback",
+ paper_id="2402.04792",
+ )
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
+ class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
+ """
+
+ Initialize OnlineDPOTrainer.
+
+ Args:
+ model (`transformers.PreTrainedModel` or `torch.nn.Module`):
+ The model to train, preferably an `AutoModelForCausalLM`.
+ ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
+ The reference model to use for training. If None is specified, the reference model will be created from
+ the model.
+ reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
+ judge (`BasePairwiseJudge`):
+ The judge to use for pairwise comparison of model completions.
+ args (`OnlineDPOConfig`):
+ The online DPO config arguments to use for training.
+ data_collator (`transformers.DataCollator`):
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+ train_dataset (`datasets.Dataset`):
+ The dataset to use for training.
+ eval_dataset (`datasets.Dataset`):
+ The dataset to use for evaluation.
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
+ for the model, and it will be saved along with the model to make it easier to rerun an interrupted training or
+ reuse the fine-tuned model.
+ peft_config (`dict`):
+ The peft config to use for training.
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
+ a dictionary mapping strings to metric values.
+ callbacks (`list[transformers.TrainerCallback]`):
+ The callbacks to use for training.
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+ The optimizer and scheduler to use for training.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+ The function to use to preprocess the logits before computing the metrics.
+
+ """
+ def __init__(
+ self,
+ model,
+ ref_model = None,
+ reward_model = None,
+ judge = None,
+ args = None,
+ data_collator = None,
+ train_dataset = None,
+ eval_dataset = None,
+ processing_class = None,
+ reward_processing_class = None,
+ peft_config = None,
+ compute_metrics = None,
+ callbacks = None,
+ preprocess_logits_for_metrics = None,
+ **kwargs
+ ):
+ if args is None: args = UnslothOnlineDPOConfig()
+ use_bf16 = getattr(args, 'bf16', False)
+ use_fp16 = getattr(args, 'fp16', False)
+ force_float32 = False
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
+ force_float32 = True
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+ dtype = getattr(model.config, 'torch_dtype', None)
+ if dtype is None: dtype = model.get_input_embeddings().dtype
+ from unsloth_zoo.utils import _get_dtype
+ dtype = _get_dtype(dtype)
+ float16 = dtype == torch.float16
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+ if force_float32:
+ args.fp16 = False
+ args.bf16 = False
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+ args.fp16 = float16
+ args.bf16 = not float16
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+ args.eval_strategy = 'steps'
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+ if ga_steps is not None and ga_steps > 1:
+ from transformers import __version__ as transformers_version
+ if Version(transformers_version) <= Version('4.45.2'):
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+ if getattr(args, 'eval_strategy', 'no') != 'no':
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+ if force_float32:
+ args.bf16_full_eval = False
+ args.fp16_full_eval = False
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+ args.bf16_full_eval = True
+ args.fp16_full_eval = False
+ elif not bf16_full_eval and not fp16_full_eval:
+ args.bf16_full_eval = args.bf16
+ args.fp16_full_eval = args.fp16
+ _output_logits = False
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+ if _output_logits:
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+ pass
+ else:
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
+ if args_max_seq_length is None and model_max_seq_length is not None:
+ max_seq_length = model.max_seq_length
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+ if model is not None and hasattr(model, 'for_training'):
+ model.for_training()
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+ if 'processing_class' in locals():
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
+ else:
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+ else:
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+ other_metrics = []
+
+ from unsloth_zoo.logging_utils import PatchRLStatistics
+ PatchRLStatistics('online_dpo_trainer', other_metrics)
+
+ super().__init__(
+ model = model,
+ ref_model = ref_model,
+ reward_model = reward_model,
+ judge = judge,
+ args = args,
+ data_collator = data_collator,
+ train_dataset = train_dataset,
+ eval_dataset = eval_dataset,
+ processing_class = processing_class,
+ reward_processing_class = reward_processing_class,
+ peft_config = peft_config,
+ compute_metrics = compute_metrics,
+ callbacks = callbacks,
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
+ if hasattr(self, 'neftune_hook_handle'):
+ self.neftune_hook_handle.remove()
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+ pass
+
+ pass
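+ # Minimal usage sketch (illustrative only, not generated by Unsloth; the dataset name,
+ # judge choice, and the `model`/`tokenizer` objects are assumptions):
+ #
+ # from datasets import load_dataset
+ # from trl import PairRMJudge
+ # train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
+ # trainer = UnslothOnlineDPOTrainer(
+ #     model = model,  # e.g. a LoRA-wrapped FastLanguageModel
+ #     judge = PairRMJudge(),  # or pass reward_model=... instead of a judge
+ #     args = UnslothOnlineDPOConfig(output_dir="outputs", per_device_train_batch_size=2),
+ #     train_dataset = train_dataset,
+ #     processing_class = tokenizer,
+ # )
+ # trainer.train()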
unsloth_compiled_cache/UnslothPPOTrainer.py ADDED
@@ -0,0 +1,1253 @@
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
+ "epilogue_fusion" : True,
+ "max_autotune" : False,
+ "shape_padding" : True,
+ "trace.enabled" : False,
+ "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+ logits = logits.to(torch.float32)
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+ # loop to reduce peak mem consumption
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
+ return per_token_logps
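+ # Illustrative example: for `logits` of shape (B, T, V) and token `index` of shape (B, T),
+ # selective_log_softmax(logits, index) returns the (B, T) log-probabilities of the indexed
+ # tokens, equivalent to logits.log_softmax(-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
+ # but without materialising the full (B, T, V) log-softmax tensor.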
+ @dataclass
+ class UnslothPPOConfig(PPOConfig):
+ """
+
+ Configuration class for the [`PPOTrainer`].
+
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
+ Name of this experiment.
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
+ Path to the reward model.
+ model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
+ ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
+ Number of epochs to train.
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
+ Whether to whiten the rewards.
+ kl_coef (`float`, *optional*, defaults to `0.05`):
+ KL coefficient.
+ cliprange (`float`, *optional*, defaults to `0.2`):
+ Clip range.
+ vf_coef (`float`, *optional*, defaults to `0.1`):
+ Value function coefficient.
+ cliprange_value (`float`, *optional*, defaults to `0.2`):
+ Clip range for the value function.
+ gamma (`float`, *optional*, defaults to `1.0`):
+ Discount factor.
+ lam (`float`, *optional*, defaults to `0.95`):
+ Lambda value for GAE.
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
+ capacity of a single GPU, albeit at the cost of slower generation.
+
+ """
+ vllm_sampling_params: Optional[Any] = field(
+ default = None,
+ metadata = {'help': 'vLLM SamplingParams'},
+ )
+ unsloth_num_chunks : Optional[int] = field(
+ default = -1,
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+ )
+ def __init__(
+ self,
+ output_dir = None,
+ overwrite_output_dir = None,
+ do_train = False,
+ do_eval = False,
+ do_predict = False,
+ eval_strategy = 'no',
+ prediction_loss_only = False,
+ per_device_train_batch_size = 4,
+ per_device_eval_batch_size = 4,
+ per_gpu_train_batch_size = None,
+ per_gpu_eval_batch_size = None,
+ gradient_accumulation_steps = 2,
+ eval_accumulation_steps = 2,
+ eval_delay = 0,
+ torch_empty_cache_steps = 250,
+ learning_rate = 5e-05,
+ weight_decay = 0.01,
+ adam_beta1 = 0.9,
+ adam_beta2 = 0.999,
+ adam_epsilon = 1e-08,
+ max_grad_norm = 1.0,
+ num_train_epochs = 3.0,
+ max_steps = -1,
+ lr_scheduler_type = 'linear',
+ warmup_ratio = 0.1,
+ warmup_steps = 0,
+ log_level = 'passive',
+ log_level_replica = 'warning',
+ log_on_each_node = True,
+ logging_dir = None,
+ logging_strategy = 'steps',
+ logging_first_step = False,
+ logging_steps = 1,
+ logging_nan_inf_filter = False,
+ save_strategy = 'steps',
+ save_steps = 500,
+ save_total_limit = None,
+ save_safetensors = True,
+ save_on_each_node = False,
+ save_only_model = False,
+ restore_callback_states_from_checkpoint = False,
+ no_cuda = False,
+ use_cpu = False,
+ use_mps_device = False,
+ seed = 3407,
+ data_seed = 3407,
+ jit_mode_eval = False,
+ use_ipex = False,
+ bf16 = False,
+ fp16 = False,
+ fp16_opt_level = 'O1',
+ half_precision_backend = 'auto',
+ bf16_full_eval = False,
+ fp16_full_eval = False,
+ tf32 = None,
+ local_rank = -1,
+ ddp_backend = None,
+ tpu_num_cores = None,
+ tpu_metrics_debug = False,
+ debug = '',
+ dataloader_drop_last = False,
+ eval_steps = None,
+ dataloader_num_workers = 0,
+ dataloader_prefetch_factor = None,
+ past_index = -1,
+ run_name = None,
+ disable_tqdm = None,
+ remove_unused_columns = True,
+ label_names = None,
+ load_best_model_at_end = False,
+ metric_for_best_model = None,
+ greater_is_better = None,
+ ignore_data_skip = False,
+ fsdp = '',
+ fsdp_min_num_params = 0,
+ fsdp_config = None,
+ tp_size = 0,
+ fsdp_transformer_layer_cls_to_wrap = None,
+ accelerator_config = None,
+ deepspeed = None,
+ label_smoothing_factor = 0.0,
+ optim = 'adamw_8bit',
+ optim_args = None,
+ adafactor = False,
+ group_by_length = False,
+ length_column_name = 'length',
+ report_to = None,
+ ddp_find_unused_parameters = None,
+ ddp_bucket_cap_mb = None,
+ ddp_broadcast_buffers = None,
+ dataloader_pin_memory = True,
+ dataloader_persistent_workers = False,
+ skip_memory_metrics = True,
+ use_legacy_prediction_loop = False,
+ push_to_hub = False,
+ resume_from_checkpoint = None,
+ hub_model_id = None,
+ hub_strategy = 'every_save',
+ hub_token = None,
+ hub_private_repo = None,
+ hub_always_push = False,
+ gradient_checkpointing = False,
+ gradient_checkpointing_kwargs = None,
+ include_inputs_for_metrics = False,
+ eval_do_concat_batches = True,
+ fp16_backend = 'auto',
+ push_to_hub_model_id = None,
+ push_to_hub_organization = None,
+ push_to_hub_token = None,
+ mp_parameters = '',
+ auto_find_batch_size = False,
+ full_determinism = False,
+ torchdynamo = None,
+ ray_scope = 'last',
+ ddp_timeout = 1800,
+ torch_compile = False,
+ torch_compile_backend = None,
+ torch_compile_mode = None,
+ include_tokens_per_second = False,
+ include_num_input_tokens_seen = False,
+ neftune_noise_alpha = None,
+ optim_target_modules = None,
+ batch_eval_metrics = False,
+ eval_on_start = False,
+ use_liger_kernel = False,
+ eval_use_gather_object = False,
+ average_tokens_across_devices = False,
+ dataset_num_proc = None,
+ num_mini_batches = 1,
+ total_episodes = None,
+ local_rollout_forward_batch_size = 64,
+ num_sample_generations = 10,
+ response_length = 53,
+ stop_token = None,
+ stop_token_id = None,
+ temperature = 0.7,
+ missing_eos_penalty = None,
+ sft_model_path = 'EleutherAI/pythia-160m',
+ world_size = None,
+ num_total_batches = None,
+ micro_batch_size = None,
+ local_batch_size = None,
+ batch_size = None,
+ local_mini_batch_size = None,
+ mini_batch_size = None,
+ exp_name = 'ppo_config',
+ reward_model_path = 'EleutherAI/pythia-160m',
+ model_adapter_name = None,
+ ref_adapter_name = None,
+ num_ppo_epochs = 4,
+ whiten_rewards = False,
+ kl_coef = 0.05,
+ cliprange = 0.2,
+ vf_coef = 0.1,
+ cliprange_value = 0.2,
+ gamma = 1.0,
+ lam = 0.95,
+ ds3_gather_for_generation = True,
+ vllm_sampling_params = None,
+ unsloth_num_chunks = -1,
+ **kwargs,
+ ):
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+ output_dir = 'unsloth_training_checkpoints'
+ save_strategy = 'no'
+ if dataset_num_proc is None:
+ from multiprocessing import cpu_count
+ dataset_num_proc = cpu_count()
+
+ super().__init__(
+ output_dir = output_dir,
+ overwrite_output_dir = overwrite_output_dir,
+ do_train = do_train,
+ do_eval = do_eval,
+ do_predict = do_predict,
+ eval_strategy = eval_strategy,
+ prediction_loss_only = prediction_loss_only,
+ per_device_train_batch_size = per_device_train_batch_size,
+ per_device_eval_batch_size = per_device_eval_batch_size,
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+ gradient_accumulation_steps = gradient_accumulation_steps,
+ eval_accumulation_steps = eval_accumulation_steps,
+ eval_delay = eval_delay,
+ torch_empty_cache_steps = torch_empty_cache_steps,
+ learning_rate = learning_rate,
+ weight_decay = weight_decay,
+ adam_beta1 = adam_beta1,
+ adam_beta2 = adam_beta2,
+ adam_epsilon = adam_epsilon,
+ max_grad_norm = max_grad_norm,
+ num_train_epochs = num_train_epochs,
+ max_steps = max_steps,
+ lr_scheduler_type = lr_scheduler_type,
+ warmup_ratio = warmup_ratio,
+ warmup_steps = warmup_steps,
+ log_level = log_level,
+ log_level_replica = log_level_replica,
+ log_on_each_node = log_on_each_node,
+ logging_dir = logging_dir,
+ logging_strategy = logging_strategy,
+ logging_first_step = logging_first_step,
+ logging_steps = logging_steps,
+ logging_nan_inf_filter = logging_nan_inf_filter,
+ save_strategy = save_strategy,
+ save_steps = save_steps,
+ save_total_limit = save_total_limit,
+ save_safetensors = save_safetensors,
+ save_on_each_node = save_on_each_node,
+ save_only_model = save_only_model,
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+ no_cuda = no_cuda,
+ use_cpu = use_cpu,
+ use_mps_device = use_mps_device,
+ seed = seed,
+ data_seed = data_seed,
+ jit_mode_eval = jit_mode_eval,
+ use_ipex = use_ipex,
+ bf16 = bf16,
+ fp16 = fp16,
+ fp16_opt_level = fp16_opt_level,
+ half_precision_backend = half_precision_backend,
+ bf16_full_eval = bf16_full_eval,
+ fp16_full_eval = fp16_full_eval,
+ tf32 = tf32,
+ local_rank = local_rank,
+ ddp_backend = ddp_backend,
+ tpu_num_cores = tpu_num_cores,
+ tpu_metrics_debug = tpu_metrics_debug,
+ debug = debug,
+ dataloader_drop_last = dataloader_drop_last,
+ eval_steps = eval_steps,
+ dataloader_num_workers = dataloader_num_workers,
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
+ past_index = past_index,
+ run_name = run_name,
+ disable_tqdm = disable_tqdm,
+ remove_unused_columns = remove_unused_columns,
+ label_names = label_names,
+ load_best_model_at_end = load_best_model_at_end,
+ metric_for_best_model = metric_for_best_model,
+ greater_is_better = greater_is_better,
+ ignore_data_skip = ignore_data_skip,
+ fsdp = fsdp,
+ fsdp_min_num_params = fsdp_min_num_params,
+ fsdp_config = fsdp_config,
+ tp_size = tp_size,
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+ accelerator_config = accelerator_config,
+ deepspeed = deepspeed,
+ label_smoothing_factor = label_smoothing_factor,
+ optim = optim,
+ optim_args = optim_args,
+ adafactor = adafactor,
+ group_by_length = group_by_length,
+ length_column_name = length_column_name,
+ report_to = report_to,
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
+ dataloader_pin_memory = dataloader_pin_memory,
+ dataloader_persistent_workers = dataloader_persistent_workers,
+ skip_memory_metrics = skip_memory_metrics,
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
+ push_to_hub = push_to_hub,
+ resume_from_checkpoint = resume_from_checkpoint,
+ hub_model_id = hub_model_id,
+ hub_strategy = hub_strategy,
+ hub_token = hub_token,
+ hub_private_repo = hub_private_repo,
+ hub_always_push = hub_always_push,
+ gradient_checkpointing = gradient_checkpointing,
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+ include_inputs_for_metrics = include_inputs_for_metrics,
+ eval_do_concat_batches = eval_do_concat_batches,
+ fp16_backend = fp16_backend,
+ push_to_hub_model_id = push_to_hub_model_id,
+ push_to_hub_organization = push_to_hub_organization,
+ push_to_hub_token = push_to_hub_token,
+ mp_parameters = mp_parameters,
+ auto_find_batch_size = auto_find_batch_size,
+ full_determinism = full_determinism,
+ torchdynamo = torchdynamo,
+ ray_scope = ray_scope,
+ ddp_timeout = ddp_timeout,
+ torch_compile = torch_compile,
+ torch_compile_backend = torch_compile_backend,
+ torch_compile_mode = torch_compile_mode,
+ include_tokens_per_second = include_tokens_per_second,
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
+ neftune_noise_alpha = neftune_noise_alpha,
+ optim_target_modules = optim_target_modules,
+ batch_eval_metrics = batch_eval_metrics,
+ eval_on_start = eval_on_start,
+ use_liger_kernel = use_liger_kernel,
+ eval_use_gather_object = eval_use_gather_object,
+ average_tokens_across_devices = average_tokens_across_devices,
+ dataset_num_proc = dataset_num_proc,
+ num_mini_batches = num_mini_batches,
+ total_episodes = total_episodes,
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
+ num_sample_generations = num_sample_generations,
+ response_length = response_length,
+ stop_token = stop_token,
+ stop_token_id = stop_token_id,
+ temperature = temperature,
+ missing_eos_penalty = missing_eos_penalty,
+ sft_model_path = sft_model_path,
+ world_size = world_size,
+ num_total_batches = num_total_batches,
+ micro_batch_size = micro_batch_size,
+ local_batch_size = local_batch_size,
+ batch_size = batch_size,
+ local_mini_batch_size = local_mini_batch_size,
+ mini_batch_size = mini_batch_size,
+ exp_name = exp_name,
+ reward_model_path = reward_model_path,
+ model_adapter_name = model_adapter_name,
+ ref_adapter_name = ref_adapter_name,
+ num_ppo_epochs = num_ppo_epochs,
+ whiten_rewards = whiten_rewards,
+ kl_coef = kl_coef,
+ cliprange = cliprange,
+ vf_coef = vf_coef,
+ cliprange_value = cliprange_value,
+ gamma = gamma,
+ lam = lam,
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
+ self.vllm_sampling_params = vllm_sampling_params
+ self.unsloth_num_chunks = unsloth_num_chunks
+ pass
+
+ class _UnslothPPOTrainer(Trainer):
+ _tag_names = ["trl", "ppo"]
+
+ def __init__(
+ self,
+ args: PPOConfig,
+ processing_class: Optional[
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+ ],
+ model: nn.Module,
+ ref_model: Optional[nn.Module],
+ reward_model: nn.Module,
+ train_dataset: Dataset,
+ value_model: Optional[nn.Module] = None,
+ data_collator: Optional[DataCollatorWithPadding] = None,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+ # less commonly used
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ callbacks: Optional[list[TrainerCallback]] = None,
+ peft_config: Optional["PeftConfig"] = None,
+ ) -> None:
+ if ref_model is model:
+ raise ValueError(
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+ "same as `model`, you must make a copy of it, or `None` if you use peft."
+ )
+
+ self.args = args
+ self.processing_class = processing_class
+ self.policy_model = model
+
+ # Define the collator if not provided
+ if data_collator is None:
+ data_collator = DataCollatorWithPadding(self.processing_class)
+
+ # Handle stop token settings: update policy model's generation_config to use provided stop token
+ if args.stop_token and args.stop_token_id:
+ raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
+ elif args.stop_token:
+ if args.stop_token == "eos":
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
+ else:
+ raise ValueError(
+ f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
+ )
+ else:
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
+
+ # peft support
+ if not is_peft_available() and peft_config is not None:
+ raise ImportError(
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+ )
+ elif is_peft_available() and peft_config is not None:
+ # if model is a peft model and we have a peft_config, we merge and unload it first
+ if isinstance(self.policy_model, PeftModel):
+ self.policy_model = self.policy_model.merge_and_unload()
+
+ # get peft model with the given config
+ self.policy_model = get_peft_model(self.policy_model, peft_config)
+ if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
+ peft_module_casting_to_bf16(self.policy_model)
+
+ self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
+ self.model_adapter_name = args.model_adapter_name
+ self.ref_adapter_name = args.ref_adapter_name
+
+ if ref_model:
+ self.ref_model = ref_model
+ elif self.is_peft_model:
+ self.ref_model = None
+ else:
+ self.ref_model = create_reference_model(self.policy_model)
+
+ self.reward_model = reward_model
+ self.train_dataset = train_dataset
+ self.train_dataset_len = len(train_dataset)
+ self.value_model = value_model
+ self.data_collator = data_collator
+ self.eval_dataset = eval_dataset
+ self.optimizer, self.lr_scheduler = optimizers
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
+
+ #########
+ # calculate various batch sizes
+ #########
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+ self.accelerator = accelerator
+ args.world_size = accelerator.num_processes
+ args.local_batch_size = (
+ args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
+ )
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
+ args.batch_size = int(args.local_batch_size * args.world_size)
+ args.mini_batch_size = exact_div(
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
+ )
+ args.local_mini_batch_size = exact_div(
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
+ )
+ if args.whiten_rewards:
+ assert (
+ args.local_mini_batch_size >= 8
+ ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
+ # `per_rank_rollout_batch_size` is our `args.local_batch_size`
+ # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
+ args.num_total_batches = math.ceil(
+ args.total_episodes / args.batch_size
+ ) # we may train for more than `total_episodes`
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
+ if args.num_sample_generations > 0:
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
+ self.local_dataloader_batch_size = args.local_batch_size
+
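+ # Illustrative example (hypothetical values) of the batch-size arithmetic above: with
+ # per_device_train_batch_size=4, gradient_accumulation_steps=2, num_mini_batches=1 and
+ # world_size=2: local_batch_size = 4*2*1 = 8, micro_batch_size = 4*2 = 8,
+ # batch_size = 8*2 = 16, mini_batch_size = 16/1 = 16, local_mini_batch_size = 8/1 = 8.
+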
+ #########
+ # setup model, optimizer, and others
+ #########
+ for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
+ if module is not None:
+ disable_dropout_in_model(module)
+ self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
+ self.model.config = self.policy_model.config # needed for pushing to hub
+ self.create_optimizer_and_scheduler(
+ num_training_steps=args.num_total_batches
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
+
+ #########
+ ### trainer specifics
+ #########
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ self.callback_handler = CallbackHandler(
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+ )
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+ self.control = TrainerControl()
+ self.state = OnlineTrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ],
+ )
+ self.current_flos = 0
+ self.hp_search_backend = None
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+ # Create remote repo and output directory if needed
+ self.hub_model_id = None
+ if self.args.push_to_hub:
+ self.init_hf_repo()
+ if self.args.should_save:
+ os.makedirs(self.args.output_dir, exist_ok=True)
+
+ # Add tags for models that have been loaded with the correct transformers version
+ if hasattr(self.model, "add_model_tags"):
+ self.model.add_model_tags(self._tag_names)
+
+ #########
+ ### setup dataloader
+ #########
+ self.dataloader = DataLoader(
+ self.train_dataset,
+ batch_size=self.local_dataloader_batch_size,
+ shuffle=True,
+ collate_fn=self.data_collator,
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
+ )
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
+ torch.manual_seed(args.seed)
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
+ torch.manual_seed(self.local_seed) # reset the local seed again
+
+ self.eval_dataloader = DataLoader(
+ self.eval_dataset,
+ batch_size=args.per_device_eval_batch_size,
+ collate_fn=self.data_collator,
+ drop_last=True,
+ ) # no need to shuffle eval dataset
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
+
+ if self.is_deepspeed_enabled:
+ self.reward_model = prepare_deepspeed(
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+ )
+
+ if self.ref_model is None:
+ if not self.is_peft_model:
+ raise ValueError("No reference model and model is not a Peft model.")
+ else:
+ self.ref_model = prepare_deepspeed(
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
+ )
+ else:
+ if self.ref_model is None:
+ if not self.is_peft_model:
+ raise ValueError("No reference model and model is not a Peft model.")
+ else:
+ self.ref_model = self.ref_model.to(self.accelerator.device)
+ self.ref_model = self.ref_model.to(self.accelerator.device)
632
+ self.reward_model = self.reward_model.to(self.accelerator.device)
633
+
634
+ def get_train_dataloader(self) -> DataLoader:
635
+ return self.dataloader
636
+
637
+ def get_eval_dataloader(self) -> DataLoader:
638
+ return self.eval_dataloader
639
+
640
+ @contextmanager
641
+ def null_ref_context(self):
642
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
643
+ with (
644
+ self.accelerator.unwrap_model(self.model.policy).disable_adapter()
645
+ if self.is_peft_model and not self.ref_adapter_name
646
+ else nullcontext()
647
+ ):
648
+ if self.ref_adapter_name:
649
+ self.model.policy.set_adapter(self.ref_adapter_name)
650
+ yield
651
+ if self.ref_adapter_name:
652
+ self.model.policy.set_adapter(self.model_adapter_name or "default")
653
+
654
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
655
+ backup_model = self.model
656
+ self.model = self.model.policy # save only the policy
657
+
658
+ if self.is_deepspeed_enabled:
659
+ backup_deepspeed = self.deepspeed
660
+ self.deepspeed = self.model
661
+
662
+ super().save_model(output_dir, _internal_call)
663
+
664
+ self.model = backup_model
665
+
666
+ if self.is_deepspeed_enabled:
667
+ self.deepspeed = backup_deepspeed
668
+
669
+ def train(self):
670
+ args = self.args
671
+ accelerator = self.accelerator
672
+ optimizer = self.optimizer
673
+ model = self.model
674
+ ref_policy = self.ref_model
675
+ reward_model = self.reward_model
676
+ processing_class = self.processing_class
677
+ dataloader = self.dataloader
678
+ device = accelerator.device
679
+
680
+ def repeat_generator():
681
+ while True:
682
+ yield from dataloader
683
+
684
+ iter_dataloader = iter(repeat_generator())
685
+ generation_config = GenerationConfig(
686
+ max_new_tokens=args.response_length,
687
+ temperature=(args.temperature + 1e-7),
688
+ top_k=0.0,
689
+ top_p=1.0,
690
+ do_sample=True,
691
+ )
692
+
693
+ accelerator.print("===training policy===")
694
+ start_time = time.time()
695
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
696
+ approxkl_stats = torch.zeros(stats_shape, device=device)
697
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
698
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
699
+ vf_loss_stats = torch.zeros(stats_shape, device=device)
700
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
701
+ entropy_stats = torch.zeros(stats_shape, device=device)
702
+ ratio_stats = torch.zeros(stats_shape, device=device)
703
+ model.train()
704
+
705
+ # trainer state initialization
706
+ self.state.global_step = 0
707
+ self.state.episode = 0
708
+ self.state.max_steps = args.num_total_batches * args.num_mini_batches
709
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
710
+ # Compute absolute values for logging, eval, and save if given as ratio
711
+ if args.logging_steps is not None:
712
+ if args.logging_steps < 1:
713
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
714
+ else:
715
+ self.state.logging_steps = args.logging_steps
716
+ if args.eval_steps is not None:
717
+ if args.eval_steps < 1:
718
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
719
+ else:
720
+ self.state.eval_steps = args.eval_steps
721
+ if args.save_steps is not None:
722
+ if args.save_steps < 1:
723
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
724
+ else:
725
+ self.state.save_steps = args.save_steps
726
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
727
+
728
+ # backward compatibility
729
+ if self.is_deepspeed_enabled:
730
+ self.deepspeed = self.model
731
+ self.model_wrapped = self.model
732
+
733
+ for update in range(1, args.num_total_batches + 1):
734
+ self.state.episode += 1 * args.batch_size
735
+ data = next(iter_dataloader)
736
+ with torch.no_grad():
737
+ queries = data["input_ids"].to(device)
738
+ context_length = queries.shape[1]
739
+ responses = []
740
+ postprocessed_responses = []
741
+ logprobs = []
742
+ ref_logprobs = []
743
+ scores = []
744
+ sequence_lengths = []
745
+ values = []
746
+ with unwrap_model_for_generation(
747
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
748
+ ) as unwrapped_model:
749
+ query_responses, logitss = batch_generation(
750
+ unwrapped_model.policy,
751
+ queries,
752
+ args.local_rollout_forward_batch_size,
753
+ processing_class.pad_token_id,
754
+ generation_config,
755
+ )
756
+
757
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
758
+ query = queries[i : i + args.local_rollout_forward_batch_size]
759
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
760
+ response = query_response[:, context_length:]
761
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
762
+ logprob = selective_log_softmax(logits, response)
763
+ del logits
764
+ torch.cuda.empty_cache()
765
+
766
+ if ref_policy is None:
767
+ with self.null_ref_context():
768
+ ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
769
+ else:
770
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
771
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
772
+ ref_logits /= args.temperature + 1e-7
773
+ ref_logprob = selective_log_softmax(ref_logits, response)
774
+ del ref_output, ref_logits
775
+ torch.cuda.empty_cache()
776
+
777
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
778
+ postprocessed_response = response
779
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
780
+ postprocessed_response = truncate_response(
781
+ self.stop_token_id, processing_class.pad_token_id, response
782
+ )
783
+
784
+ # Response Processing 2. run reward model on the truncated responses
785
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
786
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
787
+ unwrapped_value_model = accelerator.unwrap_model(model).value_model
788
+ full_value, _, _ = get_reward(
789
+ unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
790
+ )
791
+ value = full_value[:, context_length - 1 : -1].squeeze(-1)
792
+ _, score, _ = get_reward(
793
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
794
+ )
795
+
796
+ responses.append(response)
797
+ postprocessed_responses.append(postprocessed_response)
798
+ logprobs.append(logprob)
799
+ ref_logprobs.append(ref_logprob)
800
+ sequence_lengths.append(sequence_length)
801
+ scores.append(score)
802
+ values.append(value)
803
+ responses = torch.cat(responses, 0)
804
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
805
+ logprobs = torch.cat(logprobs, 0)
806
+ ref_logprobs = torch.cat(ref_logprobs, 0)
807
+ sequence_lengths = torch.cat(sequence_lengths, 0)
808
+ scores = torch.cat(scores, 0)
809
+ values = torch.cat(values, 0)
810
+ del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
811
+ torch.cuda.empty_cache()
812
+ gc.collect()
813
+
814
+ # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
815
+ # Completions not passing that filter will receive a lower score.
816
+ contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
817
+ if self.args.missing_eos_penalty is not None:
818
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
819
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
820
+
821
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
822
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
823
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
824
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
825
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
826
+ sequence_lengths_p1 = sequence_lengths + 1
827
+ padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
828
+ values = torch.masked_fill(values, padding_mask_p1, 0)
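+ # `padding_mask` hides response positions after the last real token (used for the
+ # logprobs), while `padding_mask_p1` keeps one extra slot so the value/reward entry
+ # at the position where the score is added below survives the masking.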
829
+
830
+ # 4. compute rewards
831
+ kl = logprobs - ref_logprobs
832
+ non_score_reward = -args.kl_coef * kl
833
+ rewards = non_score_reward.clone()
834
+ actual_start = torch.arange(rewards.size(0), device=rewards.device)
835
+ actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
836
+ rewards[[actual_start, actual_end]] += scores
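+ # Per-token reward is the KL penalty -kl_coef * (logprob - ref_logprob); the scalar
+ # reward-model score is added only at the last valid position of each response.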
837
+
838
+ # 5. whiten rewards
839
+ if args.whiten_rewards:
840
+ rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
841
+ rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
842
+
843
+ # 6. compute advantages and returns
844
+ lastgaelam = 0
845
+ advantages_reversed = []
846
+ gen_length = responses.shape[1]
847
+ for t in reversed(range(gen_length)):
848
+ nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
849
+ delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
850
+ lastgaelam = delta + args.gamma * args.lam * lastgaelam
851
+ advantages_reversed.append(lastgaelam)
852
+ advantages = torch.stack(advantages_reversed[::-1], axis=1)
853
+ returns = advantages + values
854
+ advantages = masked_whiten(advantages, ~padding_mask)
855
+ advantages = torch.masked_fill(advantages, padding_mask, 0)
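+ # The backward loop above is standard GAE: delta_t = r_t + gamma*V(s_{t+1}) - V(s_t),
+ # A_t = delta_t + gamma*lam*A_{t+1}, accumulated right-to-left; returns = A + V.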
856
+ torch.cuda.empty_cache()
857
+
858
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
859
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
860
+ b_inds = np.random.permutation(args.local_batch_size)
861
+ minibatch_idx = 0
862
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
863
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
864
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
865
+ gradient_accumulation_idx = 0
866
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
867
+ with accelerator.accumulate(model):
868
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
869
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
870
+ mb_advantage = advantages[micro_batch_inds]
871
+ mb_responses = responses[micro_batch_inds]
872
+ mb_query_responses = query_responses[micro_batch_inds]
873
+ mb_logprobs = logprobs[micro_batch_inds]
874
+ mb_return = returns[micro_batch_inds]
875
+ mb_values = values[micro_batch_inds]
876
+
877
+ output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
878
+ logits = output.logits[:, context_length - 1 : -1]
879
+ logits /= args.temperature + 1e-7
880
+ new_logprobs = selective_log_softmax(logits, mb_responses)
881
+ new_logprobs = torch.masked_fill(
882
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
883
+ )
884
+ vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
885
+ vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
886
+ vpredclipped = torch.clamp(
887
+ vpred,
888
+ mb_values - args.cliprange_value,
889
+ mb_values + args.cliprange_value,
890
+ )
891
+ vf_losses1 = torch.square(vpred - mb_return)
892
+ vf_losses2 = torch.square(vpredclipped - mb_return)
893
+ vf_loss_max = torch.max(vf_losses1, vf_losses2)
894
+ vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
895
+ vf_clipfrac = masked_mean(
896
+ (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
897
+ )
898
+ logprobs_diff = new_logprobs - mb_logprobs
899
+ ratio = torch.exp(logprobs_diff)
900
+ pg_losses = -mb_advantage * ratio
901
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
902
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
903
+ pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
904
+ loss = pg_loss + args.vf_coef * vf_loss
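+ # PPO clipped surrogate: maximize E[min(A*r, A*clip(r, 1-cliprange, 1+cliprange))]
+ # with r = exp(new_logprobs - old_logprobs); vf_loss is the analogously clipped
+ # squared error of the value head, weighted by vf_coef.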
905
+ accelerator.backward(loss)
906
+ optimizer.step()
907
+ optimizer.zero_grad()
908
+ with torch.no_grad():
909
+ pg_clipfrac = masked_mean(
910
+ (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
911
+ )
912
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
913
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
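+ # Stable entropy: log p_i = x_i - logsumexp(x), so
+ # H = -sum_i p_i*log p_i = logsumexp(x) - sum_i p_i*x_i.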
914
+ approxkl = 0.5 * (logprobs_diff**2).mean()
915
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
916
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
917
+ pg_clipfrac
918
+ )
919
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
920
+ vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
921
+ vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
922
+ vf_clipfrac
923
+ )
924
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
925
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
926
+ gradient_accumulation_idx += 1
927
+ minibatch_idx += 1
928
+ # del everything and empty cache
929
+ # fmt: off
930
+ del (
931
+ output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
932
+ vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
933
+ pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
934
+ mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
935
+ )
936
+ # fmt: on
937
+ torch.cuda.empty_cache()
938
+ with torch.no_grad():
939
+ mean_kl = kl.sum(1).mean()
940
+ mean_entropy = (-logprobs).sum(1).mean()
941
+ mean_non_score_reward = non_score_reward.sum(1).mean()
942
+ rlhf_reward = mean_non_score_reward + scores.mean()
943
+ eps = int(self.state.episode / (time.time() - start_time))
944
+ metrics = {}
945
+ metrics["eps"] = eps
946
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
947
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
948
+ metrics["objective/non_score_reward"] = (
949
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
950
+ )
951
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
952
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
953
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
954
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
955
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
956
+ metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
957
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
958
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
959
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
960
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
961
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
962
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
963
+ metrics["episode"] = self.state.episode
964
+ self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
965
+ self.state.global_step += 1
966
+ self.log(metrics)
967
+
968
+ self.lr_scheduler.step()
969
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
970
+ if self.control.should_save:
971
+ self._save_checkpoint(model, trial=None)
972
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
973
+ del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
974
+ torch.cuda.empty_cache()
975
+ gc.collect()
976
+
977
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
978
+ self.generate_completions(sampling=True)
979
+ torch.cuda.empty_cache()
980
+ del (
981
+ query_responses,
982
+ responses,
983
+ postprocessed_responses,
984
+ logprobs,
985
+ ref_logprobs,
986
+ values,
987
+ sequence_lengths,
988
+ contain_eos_token,
989
+ sequence_lengths_p1,
990
+ response_idxs,
991
+ padding_mask,
992
+ padding_mask_p1,
993
+ rewards,
994
+ actual_start,
995
+ actual_end,
996
+ advantages,
997
+ returns,
998
+ )
999
+ torch.cuda.empty_cache()
1000
+
1001
+ # HF trainer specifics
1002
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
1003
+ if self.control.should_save:
1004
+ self._save_checkpoint(model, trial=None, metrics=None)
1005
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1006
+
1007
+ def generate_completions(self, sampling: bool = False):
1008
+ args = self.args
1009
+ processing_class = self.processing_class
1010
+ generation_config = GenerationConfig(
1011
+ max_new_tokens=self.args.response_length,
1012
+ temperature=(0.01 + 1e-7),
1013
+ top_k=0.0,
1014
+ top_p=1.0,
1015
+ do_sample=True,
1016
+ )
1017
+
1018
+ table = defaultdict(list)
1019
+ with unwrap_model_for_generation(
1020
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
1021
+ ) as unwrapped_model:
1022
+ for batch in self.eval_dataloader:
1023
+ query = batch["input_ids"]
1024
+ with torch.no_grad():
1025
+ context_length = query.shape[1]
1026
+ query_response, _ = batch_generation(
1027
+ unwrapped_model.policy,
1028
+ query,
1029
+ query.shape[0],
1030
+ processing_class.pad_token_id,
1031
+ generation_config,
1032
+ )
1033
+ response = query_response[:, context_length:]
1034
+ postprocessed_response = response
1035
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
1036
+ postprocessed_response = truncate_response(
1037
+ self.stop_token_id, processing_class.pad_token_id, response
1038
+ )
1039
+ table["query"].extend(
1040
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
1041
+ )
1042
+ table["model response"].extend(
1043
+ gather_object(processing_class.batch_decode(postprocessed_response))
1044
+ )
1045
+
1046
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
1047
+ _, score, _ = get_reward(
1048
+ self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
1049
+ )
1050
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
1051
+
1052
+ if sampling:
1053
+ break
1054
+ df = pd.DataFrame(table)
1055
+
1056
+ if self.accelerator.is_main_process:
1057
+ print_rich_table(df.iloc[0 : 0 + 5])
1058
+ if "wandb" in args.report_to:
1059
+ import wandb
1060
+
1061
+ if wandb.run is not None:
1062
+ wandb.log({"completions": wandb.Table(dataframe=df)})
1063
+
1064
+ if "comet_ml" in args.report_to:
1065
+ log_table_to_comet_experiment(
1066
+ name="completions.csv",
1067
+ table=df,
1068
+ )
1069
+
1070
+ def create_model_card(
1071
+ self,
1072
+ model_name: Optional[str] = None,
1073
+ dataset_name: Optional[str] = None,
1074
+ tags: Union[str, list[str], None] = None,
1075
+ ):
1076
+ """
1077
+ Creates a draft of a model card using the information available to the `Trainer`.
1078
+
1079
+ Args:
1080
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1081
+ Name of the model.
1082
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1083
+ Name of the dataset used for training.
1084
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1085
+ Tags to be associated with the model card.
1086
+ """
1087
+ if not self.is_world_process_zero():
1088
+ return
1089
+
1090
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1091
+ base_model = self.model.config._name_or_path
1092
+ else:
1093
+ base_model = None
1094
+
1095
+ tags = tags or []
1096
+ if isinstance(tags, str):
1097
+ tags = [tags]
1098
+
1099
+ if hasattr(self.model.config, "unsloth_version"):
1100
+ tags.append("unsloth")
1101
+
1102
+ citation = textwrap.dedent("""\
1103
+ @article{mziegler2019fine-tuning,
1104
+ title = {{Fine-Tuning Language Models from Human Preferences}},
1105
+ author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
1106
+ year = 2019,
1107
+ eprint = {arXiv:1909.08593}
1108
+ }""")
1109
+
1110
+ model_card = generate_model_card(
1111
+ base_model=base_model,
1112
+ model_name=model_name,
1113
+ hub_model_id=self.hub_model_id,
1114
+ dataset_name=dataset_name,
1115
+ tags=tags,
1116
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1117
+ comet_url=get_comet_experiment_url(),
1118
+ trainer_name="PPO",
1119
+ trainer_citation=citation,
1120
+ paper_title="Fine-Tuning Language Models from Human Preferences",
1121
+ paper_id="1909.08593",
1122
+ )
1123
+
1124
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1125
+ class UnslothPPOTrainer(_UnslothPPOTrainer):
1126
+ """
1127
+
1128
+ """
1129
+ def __init__(
1130
+ self,
1131
+ args,
1132
+ processing_class,
1133
+ model,
1134
+ ref_model,
1135
+ reward_model,
1136
+ train_dataset,
1137
+ value_model = None,
1138
+ data_collator = None,
1139
+ eval_dataset = None,
1140
+ callbacks = None,
1141
+ peft_config = None,
1142
+ **kwargs
1143
+ ):
1144
+ if args is None: args = UnslothPPOConfig()
1145
+ use_bf16 = getattr(args, 'bf16', False)
1146
+ use_fp16 = getattr(args, 'fp16', False)
1147
+ force_float32 = False
1148
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1149
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1150
+ force_float32 = True
1151
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1152
+ dtype = getattr(model.config, 'torch_dtype', None)
1153
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1154
+ from unsloth_zoo.utils import _get_dtype
1155
+ dtype = _get_dtype(dtype)
1156
+ float16 = dtype == torch.float16
1157
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1158
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1159
+ if force_float32:
1160
+ args.fp16 = False
1161
+ args.bf16 = False
1162
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1163
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1164
+ args.fp16 = float16
1165
+ args.bf16 = not float16
1166
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1167
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1168
+ args.eval_strategy = 'steps'
1169
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1170
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1171
+ if ga_steps is not None and ga_steps > 1:
1172
+ from transformers import __version__ as transformers_version
1173
+ if Version(transformers_version) <= Version('4.45.2'):
1174
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1175
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1176
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1177
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1178
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1179
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1180
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1181
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1182
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1183
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1184
+ if force_float32:
1185
+ args.bf16_full_eval = False
1186
+ args.fp16_full_eval = False
1187
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1188
+ args.bf16_full_eval = True
1189
+ args.fp16_full_eval = False
1190
+ elif not bf16_full_eval and not fp16_full_eval:
1191
+ args.bf16_full_eval = args.bf16
1192
+ args.fp16_full_eval = args.fp16
1193
+ _output_logits = False
1194
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1195
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1196
+ if _output_logits:
1197
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1198
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1199
+ pass
1200
+ else:
1201
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1202
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1203
+ if args_max_seq_length is None and model_max_seq_length is not None:
1204
+ max_seq_length = model.max_seq_length
1205
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1206
+ if model is not None and hasattr(model, 'for_training'):
1207
+ model.for_training()
1208
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1209
+ if 'processing_class' in locals():
1210
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1211
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1212
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1213
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1214
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1215
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1216
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1217
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1218
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1219
+ else:
1220
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1221
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1222
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1223
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1224
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1225
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1226
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1227
+ else:
1228
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1229
+ other_metrics = []
1230
+
1231
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1232
+ PatchRLStatistics('ppo_trainer', other_metrics)
1233
+
1234
+ super().__init__(
1235
+ args = args,
1236
+ processing_class = processing_class,
1237
+ model = model,
1238
+ ref_model = ref_model,
1239
+ reward_model = reward_model,
1240
+ train_dataset = train_dataset,
1241
+ value_model = value_model,
1242
+ data_collator = data_collator,
1243
+ eval_dataset = eval_dataset,
1244
+ callbacks = callbacks,
1245
+ peft_config = peft_config,**kwargs)
1246
+ if hasattr(self, 'neftune_hook_handle'):
1247
+ self.neftune_hook_handle.remove()
1248
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1249
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1250
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1251
+ pass
1252
+
1253
+ pass
unsloth_compiled_cache/UnslothPRMTrainer.py ADDED
@@ -0,0 +1,794 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
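+ # Reference equivalent (assuming logits of shape [B, T, V] and index of shape [B, T]):
+ # F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
+ # returns the same per-token log-probs but materializes the full [B, T, V] tensor.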
42
+ @dataclass
43
+ class UnslothPRMConfig(PRMConfig):
44
+ """
45
+
46
+ Configuration class for the [`PRMTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-5`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) used for truncation.
58
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
59
+ Maximum length of the prompt used for truncation.
60
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
61
+ Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
62
+ disable_dropout (`bool`, *optional*, defaults to `True`):
63
+ Whether to disable dropout in the model.
64
+ step_separator (`str`, *optional*, defaults to `"\n"`):
65
+ Separator used to separate each step of the reasoning process.
66
+ train_on_last_step_only (`bool`, *optional*, defaults to `False`):
67
+ Whether to train only on the last step.
68
+ dataset_num_proc (`int`, *optional*, defaults to `None`):
69
+ Number of processes to use for processing the dataset.
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ tp_size = 0,
159
+ fsdp_transformer_layer_cls_to_wrap = None,
160
+ accelerator_config = None,
161
+ deepspeed = None,
162
+ label_smoothing_factor = 0.0,
163
+ optim = 'adamw_8bit',
164
+ optim_args = None,
165
+ adafactor = False,
166
+ group_by_length = False,
167
+ length_column_name = 'length',
168
+ report_to = None,
169
+ ddp_find_unused_parameters = None,
170
+ ddp_bucket_cap_mb = None,
171
+ ddp_broadcast_buffers = None,
172
+ dataloader_pin_memory = True,
173
+ dataloader_persistent_workers = False,
174
+ skip_memory_metrics = True,
175
+ use_legacy_prediction_loop = False,
176
+ push_to_hub = False,
177
+ resume_from_checkpoint = None,
178
+ hub_model_id = None,
179
+ hub_strategy = 'every_save',
180
+ hub_token = None,
181
+ hub_private_repo = None,
182
+ hub_always_push = False,
183
+ gradient_checkpointing = False,
184
+ gradient_checkpointing_kwargs = None,
185
+ include_inputs_for_metrics = False,
186
+ eval_do_concat_batches = True,
187
+ fp16_backend = 'auto',
188
+ push_to_hub_model_id = None,
189
+ push_to_hub_organization = None,
190
+ push_to_hub_token = None,
191
+ mp_parameters = '',
192
+ auto_find_batch_size = False,
193
+ full_determinism = False,
194
+ torchdynamo = None,
195
+ ray_scope = 'last',
196
+ ddp_timeout = 1800,
197
+ torch_compile = False,
198
+ torch_compile_backend = None,
199
+ torch_compile_mode = None,
200
+ include_tokens_per_second = False,
201
+ include_num_input_tokens_seen = False,
202
+ neftune_noise_alpha = None,
203
+ optim_target_modules = None,
204
+ batch_eval_metrics = False,
205
+ eval_on_start = False,
206
+ use_liger_kernel = False,
207
+ eval_use_gather_object = False,
208
+ average_tokens_across_devices = False,
209
+ max_length = 1024,
210
+ max_prompt_length = 512,
211
+ max_completion_length = None,
212
+ disable_dropout = True,
213
+ step_separator = '\n',
215
+ train_on_last_step_only = False,
216
+ dataset_num_proc = None,
217
+ vllm_sampling_params = None,
218
+ unsloth_num_chunks = -1,
219
+ **kwargs,
220
+ ):
221
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
222
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
223
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
224
+ output_dir = 'unsloth_training_checkpoints'
225
+ save_strategy = 'no'
226
+ if dataset_num_proc is None:
227
+ from multiprocessing import cpu_count
228
+ dataset_num_proc = cpu_count()
229
+
230
+ super().__init__(
231
+ output_dir = output_dir,
232
+ overwrite_output_dir = overwrite_output_dir,
233
+ do_train = do_train,
234
+ do_eval = do_eval,
235
+ do_predict = do_predict,
236
+ eval_strategy = eval_strategy,
237
+ prediction_loss_only = prediction_loss_only,
238
+ per_device_train_batch_size = per_device_train_batch_size,
239
+ per_device_eval_batch_size = per_device_eval_batch_size,
240
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
241
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
242
+ gradient_accumulation_steps = gradient_accumulation_steps,
243
+ eval_accumulation_steps = eval_accumulation_steps,
244
+ eval_delay = eval_delay,
245
+ torch_empty_cache_steps = torch_empty_cache_steps,
246
+ learning_rate = learning_rate,
247
+ weight_decay = weight_decay,
248
+ adam_beta1 = adam_beta1,
249
+ adam_beta2 = adam_beta2,
250
+ adam_epsilon = adam_epsilon,
251
+ max_grad_norm = max_grad_norm,
252
+ num_train_epochs = num_train_epochs,
253
+ max_steps = max_steps,
254
+ lr_scheduler_type = lr_scheduler_type,
255
+ warmup_ratio = warmup_ratio,
256
+ warmup_steps = warmup_steps,
257
+ log_level = log_level,
258
+ log_level_replica = log_level_replica,
259
+ log_on_each_node = log_on_each_node,
260
+ logging_dir = logging_dir,
261
+ logging_strategy = logging_strategy,
262
+ logging_first_step = logging_first_step,
263
+ logging_steps = logging_steps,
264
+ logging_nan_inf_filter = logging_nan_inf_filter,
265
+ save_strategy = save_strategy,
266
+ save_steps = save_steps,
267
+ save_total_limit = save_total_limit,
268
+ save_safetensors = save_safetensors,
269
+ save_on_each_node = save_on_each_node,
270
+ save_only_model = save_only_model,
271
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
272
+ no_cuda = no_cuda,
273
+ use_cpu = use_cpu,
274
+ use_mps_device = use_mps_device,
275
+ seed = seed,
276
+ data_seed = data_seed,
277
+ jit_mode_eval = jit_mode_eval,
278
+ use_ipex = use_ipex,
279
+ bf16 = bf16,
280
+ fp16 = fp16,
281
+ fp16_opt_level = fp16_opt_level,
282
+ half_precision_backend = half_precision_backend,
283
+ bf16_full_eval = bf16_full_eval,
284
+ fp16_full_eval = fp16_full_eval,
285
+ tf32 = tf32,
286
+ local_rank = local_rank,
287
+ ddp_backend = ddp_backend,
288
+ tpu_num_cores = tpu_num_cores,
289
+ tpu_metrics_debug = tpu_metrics_debug,
290
+ debug = debug,
291
+ dataloader_drop_last = dataloader_drop_last,
292
+ eval_steps = eval_steps,
293
+ dataloader_num_workers = dataloader_num_workers,
294
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
295
+ past_index = past_index,
296
+ run_name = run_name,
297
+ disable_tqdm = disable_tqdm,
298
+ remove_unused_columns = remove_unused_columns,
299
+ label_names = label_names,
300
+ load_best_model_at_end = load_best_model_at_end,
301
+ metric_for_best_model = metric_for_best_model,
302
+ greater_is_better = greater_is_better,
303
+ ignore_data_skip = ignore_data_skip,
304
+ fsdp = fsdp,
305
+ fsdp_min_num_params = fsdp_min_num_params,
306
+ fsdp_config = fsdp_config,
307
+ tp_size = tp_size,
308
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
309
+ accelerator_config = accelerator_config,
310
+ deepspeed = deepspeed,
311
+ label_smoothing_factor = label_smoothing_factor,
312
+ optim = optim,
313
+ optim_args = optim_args,
314
+ adafactor = adafactor,
315
+ group_by_length = group_by_length,
316
+ length_column_name = length_column_name,
317
+ report_to = report_to,
318
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
319
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
320
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
321
+ dataloader_pin_memory = dataloader_pin_memory,
322
+ dataloader_persistent_workers = dataloader_persistent_workers,
323
+ skip_memory_metrics = skip_memory_metrics,
324
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
325
+ push_to_hub = push_to_hub,
326
+ resume_from_checkpoint = resume_from_checkpoint,
327
+ hub_model_id = hub_model_id,
328
+ hub_strategy = hub_strategy,
329
+ hub_token = hub_token,
330
+ hub_private_repo = hub_private_repo,
331
+ hub_always_push = hub_always_push,
332
+ gradient_checkpointing = gradient_checkpointing,
333
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
334
+ include_inputs_for_metrics = include_inputs_for_metrics,
335
+ eval_do_concat_batches = eval_do_concat_batches,
336
+ fp16_backend = fp16_backend,
337
+ push_to_hub_model_id = push_to_hub_model_id,
338
+ push_to_hub_organization = push_to_hub_organization,
339
+ push_to_hub_token = push_to_hub_token,
340
+ mp_parameters = mp_parameters,
341
+ auto_find_batch_size = auto_find_batch_size,
342
+ full_determinism = full_determinism,
343
+ torchdynamo = torchdynamo,
344
+ ray_scope = ray_scope,
345
+ ddp_timeout = ddp_timeout,
346
+ torch_compile = torch_compile,
347
+ torch_compile_backend = torch_compile_backend,
348
+ torch_compile_mode = torch_compile_mode,
349
+ include_tokens_per_second = include_tokens_per_second,
350
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
351
+ neftune_noise_alpha = neftune_noise_alpha,
352
+ optim_target_modules = optim_target_modules,
353
+ batch_eval_metrics = batch_eval_metrics,
354
+ eval_on_start = eval_on_start,
355
+ use_liger_kernel = use_liger_kernel,
356
+ eval_use_gather_object = eval_use_gather_object,
357
+ average_tokens_across_devices = average_tokens_across_devices,
358
+ max_length = max_length,
359
+ max_prompt_length = max_prompt_length,
360
+ max_completion_length = max_completion_length,
361
+ disable_dropout = disable_dropout,
362
+ step_separator = step_separator,
363
+ train_on_last_step_only = train_on_last_step_only,
364
+ dataset_num_proc = dataset_num_proc,**kwargs)
365
+ self.vllm_sampling_params = vllm_sampling_params
366
+ self.unsloth_num_chunks = unsloth_num_chunks
367
+ pass
368
+
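+ # Minimal usage sketch (all names below are placeholders, not defined in this file):
+ #   config = UnslothPRMConfig(output_dir = "prm_checkpoints", max_length = 1024)
+ #   trainer = UnslothPRMTrainer(model = model, args = config,
+ #                               train_dataset = dataset, processing_class = tokenizer)
+ #   trainer.train()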
369
+ class _UnslothPRMTrainer(Trainer):
370
+ """"""
371
+
372
+ _tag_names = ["trl", "prm"]
373
+
374
+ def __init__(
375
+ self,
376
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
377
+ args: Optional[PRMConfig] = None,
378
+ data_collator: Optional[DataCollator] = None,
379
+ train_dataset: Optional[Dataset] = None,
380
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
381
+ processing_class: Optional[
382
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
383
+ ] = None,
384
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
385
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
386
+ callbacks: Optional[list[TrainerCallback]] = None,
387
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
388
+ None,
389
+ None,
390
+ ),
391
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
392
+ peft_config: Optional[dict] = None,
393
+ ):
394
+ if not is_peft_available() and peft_config is not None:
395
+ raise ValueError(
396
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
397
+ )
398
+ elif is_peft_available() and peft_config is not None:
399
+ if not isinstance(model, PeftModel):
400
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
401
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
402
+ inspect.signature(prepare_model_for_kbit_training).parameters
403
+ )
404
+
405
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
406
+
407
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
408
+ warnings.warn(
409
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
410
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
411
+ )
412
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
413
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
414
+
415
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
416
+
417
+ model = model
418
+
419
+ # Disable dropout in the model
420
+ if args.disable_dropout:
421
+ disable_dropout_in_model(model)
422
+
423
+ if compute_metrics is None:
424
+ compute_metrics = compute_accuracy
425
+
426
+ if data_collator is None:
427
+ if processing_class is None:
428
+ raise ValueError(
429
+ "A processing_class must be specified when using the default DataCollatorForTokenClassification"
430
+ )
431
+ data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
432
+
433
+ if "input_ids" not in train_dataset.column_names:
434
+ with PartialState().local_main_process_first():
435
+ fn_kwargs = {
436
+ "tokenizer": processing_class,
437
+ "step_separator": args.step_separator,
438
+ "max_length": args.max_length,
439
+ "max_prompt_length": args.max_prompt_length,
440
+ "max_completion_length": args.max_completion_length,
441
+ "train_on_last_step_only": args.train_on_last_step_only,
442
+ }
443
+ train_fn_kwargs = {**fn_kwargs, "is_eval": False}
444
+ train_dataset = train_dataset.map(
445
+ self.tokenize_row,
446
+ fn_kwargs=train_fn_kwargs,
447
+ num_proc=args.dataset_num_proc,
448
+ remove_columns=train_dataset.features,
449
+ desc="Tokenizing train dataset",
450
+ features=features.Features( # needed to avoid map to cast labels to bool
451
+ {
452
+ "labels": features.Sequence(features.Value("int64")),
453
+ "input_ids": features.Sequence(features.Value("int64")),
454
+ }
455
+ ),
456
+ )
457
+
458
+ eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
459
+ if eval_dataset is not None:
460
+ eval_dataset = eval_dataset.map(
461
+ self.tokenize_row,
462
+ fn_kwargs=eval_fn_kwargs,
463
+ num_proc=args.dataset_num_proc,
464
+ remove_columns=eval_dataset.features,
465
+ desc="Tokenizing eval dataset",
466
+ features=features.Features( # needed to avoid map to cast labels to bool
467
+ {
468
+ "labels": features.Sequence(features.Value("int64")),
469
+ "input_ids": features.Sequence(features.Value("int64")),
470
+ }
471
+ ),
472
+ )
473
+
474
+ super().__init__(
475
+ model=model,
476
+ args=args,
477
+ data_collator=data_collator,
478
+ train_dataset=train_dataset,
479
+ eval_dataset=eval_dataset,
480
+ processing_class=processing_class,
481
+ model_init=model_init,
482
+ compute_metrics=compute_metrics,
483
+ callbacks=callbacks,
484
+ optimizers=optimizers,
485
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
486
+ )
487
+
488
+ # Add tags for models that have been loaded with the correct transformers version
489
+ if hasattr(self.model, "add_model_tags"):
490
+ self.model.add_model_tags(self._tag_names)
491
+
492
+ @staticmethod
493
+ def tokenize_row(
494
+ features,
495
+ tokenizer,
496
+ step_separator,
497
+ max_length,
498
+ max_prompt_length,
499
+ max_completion_length,
500
+ train_on_last_step_only,
501
+ is_eval,
502
+ ):
503
+ r"""
504
+ Tokenize a row of the dataset.
505
+
506
+ Args:
507
+ features (`dict[str, str]`):
508
+ Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
509
+ tokenizer (`PreTrainedTokenizerBase`):
510
+ Tokenizer used to process the data.
511
+ step_separator (`str`):
512
+ Separator between steps in the completion.
513
+ max_length (`int` or `None`):
514
+ Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
515
+ max_prompt_length (`int` or `None`):
516
+ Maximum length of the prompt. If `None`, the prompt is not truncated.
517
+ max_completion_length (`int` or `None`):
518
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
519
+ train_on_last_step_only (`bool`):
520
+ Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
521
+ token of the completion.
522
+ is_eval (`bool`):
523
+ Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
524
+
525
+ Returns:
526
+ `dict[str, list[int]]`:
527
+ Tokenized sequences with the keys `"input_ids"` and `"labels"`.
528
+
529
+ Example:
530
+ ```python
531
+ >>> from transformers import AutoTokenizer
532
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
533
+ >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
534
+ ... "completions": ["11 is greater than 8.",
535
+ ... "Hence, 9.11 > 9.8."],
536
+ ... "labels": [True, False]}
537
+ >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
538
+ {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
539
+ 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
540
+ ```
541
+ """
542
+ # Tokenize the prompt and completions
543
+ prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
544
+ completions_ids = [
545
+ tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
546
+ ]
547
+ if train_on_last_step_only and not is_eval:
548
+ labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
549
+ else:
550
+ labels = [int(label) for label in features["labels"]]
551
+
552
+ # Get the ID of the separator token and add it to the completions
553
+ separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
554
+ completions_ids = [completion + separator_ids for completion in completions_ids]
555
+
556
+ # Create the label
557
+ labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
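+ # Only the separator token closing each step carries that step's 0/1 label; all other
+ # completion tokens are set to -100 and ignored by the token-classification loss.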
558
+
559
+ # Join the completions and labels steps
560
+ completion_ids = list(chain(*completions_ids))
561
+ labels = list(chain(*labels))
562
+
563
+ if tokenizer.bos_token_id is not None:
564
+ prompt_ids = [tokenizer.bos_token_id] + prompt_ids
565
+
566
+ # Truncate prompt and completion sequences
567
+ if max_prompt_length is not None:
568
+ prompt_ids = prompt_ids[-max_prompt_length:]
569
+ if max_completion_length is not None:
570
+ completion_ids = completion_ids[:max_completion_length]
571
+ labels = labels[:max_completion_length]
572
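+ # Note the asymmetry: the prompt keeps its *last* `max_prompt_length` tokens
+ # (left-truncated), while the completion keeps its *first* `max_completion_length`
+ # tokens; `labels` is truncated identically so it stays aligned with `completion_ids`.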
+
573
+ input_ids = prompt_ids + completion_ids
574
+ labels = [-100] * len(prompt_ids) + labels
575
+
576
+ if max_length is not None:
577
+ input_ids = input_ids[:max_length]
578
+ labels = labels[:max_length]
579
+
580
+ return {"input_ids": input_ids, "labels": labels}
581
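+ 
+ # Label layout sketch for the docstring example above (assuming the "\n" separator
+ # encodes to a single token): every prompt token gets -100, and within each step
+ # every token gets -100 except the separator that closes the step, which carries
+ # that step's 0/1 label -- hence the lone `1` and the trailing `0` in the output.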
+
582
+ def create_model_card(
583
+ self,
584
+ model_name: Optional[str] = None,
585
+ dataset_name: Optional[str] = None,
586
+ tags: Union[str, list[str], None] = None,
587
+ ):
588
+ """
589
+ Creates a draft of a model card using the information available to the `Trainer`.
590
+
591
+ Args:
592
+ model_name (`str` or `None`, *optional*, defaults to `None`):
593
+ Name of the model.
594
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
595
+ Name of the dataset used for training.
596
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
597
+ Tags to be associated with the model card.
598
+ """
599
+ if not self.is_world_process_zero():
600
+ return
601
+
602
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
603
+ base_model = self.model.config._name_or_path
604
+ else:
605
+ base_model = None
606
+
607
+ tags = tags or []
608
+ if isinstance(tags, str):
609
+ tags = [tags]
610
+
611
+ if hasattr(self.model.config, "unsloth_version"):
612
+ tags.append("unsloth")
613
+
614
+ citation = textwrap.dedent("""\
615
+ @article{uesato2022solving,
616
+ title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
617
+ author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
618
+ year = 2022,
619
+ journal = {arXiv preprint arXiv:2211.14275}
620
+ }""")
621
+
622
+ model_card = generate_model_card(
623
+ base_model=base_model,
624
+ model_name=model_name,
625
+ hub_model_id=self.hub_model_id,
626
+ dataset_name=dataset_name,
627
+ tags=tags,
628
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
629
+ trainer_name="PRM",
630
+ trainer_citation=citation,
631
+ paper_title="Solving math word problems with process-and outcome-based feedback",
632
+ )
633
+
634
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
635
+ class UnslothPRMTrainer(_UnslothPRMTrainer):
636
+ """
637
+
638
+ Initialize PRMTrainer.
639
+
640
+ Args:
641
+ model (`transformers.PreTrainedModel`):
642
+ The model to train, preferably an `AutoModelForTokenClassification`.
643
+ args (`PRMConfig`):
644
+ The arguments to use for training.
645
+ data_collator (`transformers.DataCollator`):
646
+ The data collator to use for training. If `None` is specified, the default data collator (`DataCollatorForTokenClassification`) will be used,
647
+ which pads the sequences to the maximum length of the sequences in the batch.
648
+ train_dataset (`datasets.Dataset`):
649
+ The dataset to use for training.
650
+ eval_dataset (`datasets.Dataset`):
651
+ The dataset to use for evaluation.
652
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
653
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
654
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
655
+ reuse the fine-tuned model.
656
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
657
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
658
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
659
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
660
+ callbacks (`list[transformers.TrainerCallback]`):
661
+ The callbacks to use for training.
662
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
663
+ The optimizer and scheduler to use for training.
664
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
665
+ The function to use to preprocess the logits before computing the metrics.
666
+ peft_config (`dict`, defaults to `None`):
667
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
668
+
669
+ """
670
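+ # Minimal usage sketch (names and dataset are illustrative assumptions):
+ #
+ # trainer = UnslothPRMTrainer(
+ #     model = model,
+ #     args = UnslothPRMConfig(output_dir = "prm_out"),
+ #     train_dataset = stepwise_dataset,  # rows with "prompt", "completions", "labels"
+ #     processing_class = tokenizer,
+ # )
+ # trainer.train()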
+ def __init__(
671
+ self,
672
+ model = None,
673
+ args = None,
674
+ data_collator = None,
675
+ train_dataset = None,
676
+ eval_dataset = None,
677
+ processing_class = None,
678
+ model_init = None,
679
+ compute_metrics = None,
680
+ callbacks = None,
681
+ preprocess_logits_for_metrics = None,
682
+ peft_config = None,
683
+ **kwargs
684
+ ):
685
+ if args is None: args = UnslothPRMConfig()
686
+ use_bf16 = getattr(args, 'bf16', False)
687
+ use_fp16 = getattr(args, 'fp16', False)
688
+ force_float32 = False
689
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
690
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
691
+ force_float32 = True
692
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
693
+ dtype = getattr(model.config, 'torch_dtype', None)
694
+ if dtype is None: dtype = model.get_input_embeddings().dtype
695
+ from unsloth_zoo.utils import _get_dtype
696
+ dtype = _get_dtype(dtype)
697
+ float16 = dtype == torch.float16
698
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
699
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
700
+ if force_float32:
701
+ args.fp16 = False
702
+ args.bf16 = False
703
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
704
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
705
+ args.fp16 = float16
706
+ args.bf16 = not float16
707
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
708
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
709
+ args.eval_strategy = 'steps'
710
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
711
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
712
+ if ga_steps is not None and ga_steps > 1:
713
+ from transformers import __version__ as transformers_version
714
+ if Version(transformers_version) <= Version('4.45.2'):
715
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
716
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
717
+ if getattr(args, 'eval_strategy', 'no') != 'no':
718
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
719
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
720
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
721
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
722
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
723
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
724
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
725
+ if force_float32:
726
+ args.bf16_full_eval = False
727
+ args.fp16_full_eval = False
728
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
729
+ args.bf16_full_eval = True
730
+ args.fp16_full_eval = False
731
+ elif not bf16_full_eval and not fp16_full_eval:
732
+ args.bf16_full_eval = args.bf16
733
+ args.fp16_full_eval = args.fp16
734
+ _output_logits = False
735
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
736
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
737
+ if _output_logits:
738
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
739
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
740
+ pass
741
+ else:
742
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
743
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
744
+ if args_max_seq_length is None and model_max_seq_length is not None:
745
+ max_seq_length = model.max_seq_length
746
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
747
+ if model is not None and hasattr(model, 'for_training'):
748
+ model.for_training()
749
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
750
+ if 'processing_class' in locals():
751
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
752
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
753
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
754
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
755
+ if not isinstance(data_collator, UnslothVisionDataCollator):
756
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
757
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
758
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
759
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
760
+ else:
761
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
762
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
763
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
764
+ if not isinstance(data_collator, UnslothVisionDataCollator):
765
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
766
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
767
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
768
+ else:
769
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
770
+ other_metrics = []
771
+
772
+ from unsloth_zoo.logging_utils import PatchRLStatistics
773
+ PatchRLStatistics('prm_trainer', other_metrics)
774
+
775
+ super().__init__(
776
+ model = model,
777
+ args = args,
778
+ data_collator = data_collator,
779
+ train_dataset = train_dataset,
780
+ eval_dataset = eval_dataset,
781
+ processing_class = processing_class,
782
+ model_init = model_init,
783
+ compute_metrics = compute_metrics,
784
+ callbacks = callbacks,
785
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
786
+ peft_config = peft_config, **kwargs)
787
+ if hasattr(self, 'neftune_hook_handle'):
788
+ self.neftune_hook_handle.remove()
789
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
790
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
791
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
792
+ pass
793
+
794
+ pass
unsloth_compiled_cache/UnslothRLOOTrainer.py ADDED
@@ -0,0 +1,1127 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
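+ 
+ # Illustrative check (an assumed helper, not called by the trainer): the gathered
+ # logits minus logsumexp above should match gathering from a full log_softmax.
+ def _check_selective_log_softmax(batch = 2, seq = 3, vocab = 11):
+ logits = torch.randn(batch, seq, vocab)
+ index = torch.randint(0, vocab, (batch, seq))
+ expected = F.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
+ assert torch.allclose(selective_log_softmax(logits, index), expected, atol = 1e-5)
+ 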
+ @dataclass
43
+ class UnslothRLOOConfig(RLOOConfig):
44
+ """
45
+
46
+ Configuration class for the [`RLOOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
54
+ Name of this experiment.
55
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
56
+ Path to the reward model.
57
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
58
+ Number of epochs to train.
59
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
60
+ Whether to whiten the rewards.
61
+ kl_coef (`float`, *optional*, defaults to `0.05`):
62
+ KL coefficient.
63
+ cliprange (`float`, *optional*, defaults to `0.2`):
64
+ Clip range.
65
+ rloo_k (`int`, *optional*, defaults to `2`):
66
+ REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
67
+ normalize_reward (`bool`, *optional*, defaults to `False`):
68
+ Whether to normalize rewards.
69
+ reward_clip_range (`float`, *optional*, defaults to `10.0`):
70
+ Clip range for rewards.
71
+ normalize_advantage (`bool`, *optional*, defaults to `False`):
72
+ Whether to normalize advantages.
73
+ token_level_kl (`bool`, *optional*, defaults to `False`):
74
+ Whether to use token-level KL penalty or sequence-level KL penalty.
75
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
76
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
77
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
78
+ capacity of a single GPU, albeit at the cost of slower generation.
79
+
80
+ """
81
+ vllm_sampling_params: Optional[Any] = field(
82
+ default = None,
83
+ metadata = {'help': 'vLLM SamplingParams'},
84
+ )
85
+ unsloth_num_chunks : Optional[int] = field(
86
+ default = -1,
87
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
88
+ )
89
+ def __init__(
90
+ self,
91
+ output_dir = None,
92
+ overwrite_output_dir = None,
93
+ do_train = False,
94
+ do_eval = False,
95
+ do_predict = False,
96
+ eval_strategy = 'no',
97
+ prediction_loss_only = False,
98
+ per_device_train_batch_size = 4,
99
+ per_device_eval_batch_size = 4,
100
+ per_gpu_train_batch_size = None,
101
+ per_gpu_eval_batch_size = None,
102
+ gradient_accumulation_steps = 2,
103
+ eval_accumulation_steps = 2,
104
+ eval_delay = 0,
105
+ torch_empty_cache_steps = 250,
106
+ learning_rate = 5e-05,
107
+ weight_decay = 0.01,
108
+ adam_beta1 = 0.9,
109
+ adam_beta2 = 0.999,
110
+ adam_epsilon = 1e-08,
111
+ max_grad_norm = 1.0,
112
+ num_train_epochs = 3.0,
113
+ max_steps = -1,
114
+ lr_scheduler_type = 'linear',
115
+ warmup_ratio = 0.1,
116
+ warmup_steps = 0,
117
+ log_level = 'passive',
118
+ log_level_replica = 'warning',
119
+ log_on_each_node = True,
120
+ logging_dir = None,
121
+ logging_strategy = 'steps',
122
+ logging_first_step = False,
123
+ logging_steps = 1,
124
+ logging_nan_inf_filter = False,
125
+ save_strategy = 'steps',
126
+ save_steps = 500,
127
+ save_total_limit = None,
128
+ save_safetensors = True,
129
+ save_on_each_node = False,
130
+ save_only_model = False,
131
+ restore_callback_states_from_checkpoint = False,
132
+ no_cuda = False,
133
+ use_cpu = False,
134
+ use_mps_device = False,
135
+ seed = 3407,
136
+ data_seed = 3407,
137
+ jit_mode_eval = False,
138
+ use_ipex = False,
139
+ bf16 = False,
140
+ fp16 = False,
141
+ fp16_opt_level = 'O1',
142
+ half_precision_backend = 'auto',
143
+ bf16_full_eval = False,
144
+ fp16_full_eval = False,
145
+ tf32 = None,
146
+ local_rank = -1,
147
+ ddp_backend = None,
148
+ tpu_num_cores = None,
149
+ tpu_metrics_debug = False,
150
+ debug = '',
151
+ dataloader_drop_last = False,
152
+ eval_steps = None,
153
+ dataloader_num_workers = 0,
154
+ dataloader_prefetch_factor = None,
155
+ past_index = -1,
156
+ run_name = None,
157
+ disable_tqdm = None,
158
+ remove_unused_columns = True,
159
+ label_names = None,
160
+ load_best_model_at_end = False,
161
+ metric_for_best_model = None,
162
+ greater_is_better = None,
163
+ ignore_data_skip = False,
164
+ fsdp = '',
165
+ fsdp_min_num_params = 0,
166
+ fsdp_config = None,
167
+ tp_size = 0,
168
+ fsdp_transformer_layer_cls_to_wrap = None,
169
+ accelerator_config = None,
170
+ deepspeed = None,
171
+ label_smoothing_factor = 0.0,
172
+ optim = 'adamw_8bit',
173
+ optim_args = None,
174
+ adafactor = False,
175
+ group_by_length = False,
176
+ length_column_name = 'length',
177
+ report_to = None,
178
+ ddp_find_unused_parameters = None,
179
+ ddp_bucket_cap_mb = None,
180
+ ddp_broadcast_buffers = None,
181
+ dataloader_pin_memory = True,
182
+ dataloader_persistent_workers = False,
183
+ skip_memory_metrics = True,
184
+ use_legacy_prediction_loop = False,
185
+ push_to_hub = False,
186
+ resume_from_checkpoint = None,
187
+ hub_model_id = None,
188
+ hub_strategy = 'every_save',
189
+ hub_token = None,
190
+ hub_private_repo = None,
191
+ hub_always_push = False,
192
+ gradient_checkpointing = False,
193
+ gradient_checkpointing_kwargs = None,
194
+ include_inputs_for_metrics = False,
195
+ eval_do_concat_batches = True,
196
+ fp16_backend = 'auto',
197
+ push_to_hub_model_id = None,
198
+ push_to_hub_organization = None,
199
+ push_to_hub_token = None,
200
+ mp_parameters = '',
201
+ auto_find_batch_size = False,
202
+ full_determinism = False,
203
+ torchdynamo = None,
204
+ ray_scope = 'last',
205
+ ddp_timeout = 1800,
206
+ torch_compile = False,
207
+ torch_compile_backend = None,
208
+ torch_compile_mode = None,
209
+ include_tokens_per_second = False,
210
+ include_num_input_tokens_seen = False,
211
+ neftune_noise_alpha = None,
212
+ optim_target_modules = None,
213
+ batch_eval_metrics = False,
214
+ eval_on_start = False,
215
+ use_liger_kernel = False,
216
+ eval_use_gather_object = False,
217
+ average_tokens_across_devices = False,
218
+ dataset_num_proc = None,
219
+ num_mini_batches = 1,
220
+ total_episodes = None,
221
+ local_rollout_forward_batch_size = 64,
222
+ num_sample_generations = 10,
223
+ response_length = 53,
224
+ stop_token = None,
225
+ stop_token_id = None,
226
+ temperature = 0.7,
227
+ missing_eos_penalty = None,
228
+ sft_model_path = 'EleutherAI/pythia-160m',
229
+ world_size = None,
230
+ num_total_batches = None,
231
+ micro_batch_size = None,
232
+ local_batch_size = None,
233
+ batch_size = None,
234
+ local_mini_batch_size = None,
235
+ mini_batch_size = None,
236
+ exp_name = 'rloo_config',
237
+ reward_model_path = 'EleutherAI/pythia-160m',
238
+ num_ppo_epochs = 4,
239
+ whiten_rewards = False,
240
+ kl_coef = 0.05,
241
+ cliprange = 0.2,
242
+ rloo_k = 2,
243
+ normalize_reward = False,
244
+ reward_clip_range = 10.0,
245
+ normalize_advantage = False,
246
+ token_level_kl = False,
247
+ ds3_gather_for_generation = True,
248
+ vllm_sampling_params = None,
249
+ unsloth_num_chunks = -1,
250
+ **kwargs,
251
+ ):
252
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
253
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
254
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
255
+ output_dir = 'unsloth_training_checkpoints'
256
+ save_strategy = 'no'
257
+ if dataset_num_proc is None:
258
+ from multiprocessing import cpu_count
259
+ dataset_num_proc = cpu_count()
260
+
261
+ super().__init__(
262
+ output_dir = output_dir,
263
+ overwrite_output_dir = overwrite_output_dir,
264
+ do_train = do_train,
265
+ do_eval = do_eval,
266
+ do_predict = do_predict,
267
+ eval_strategy = eval_strategy,
268
+ prediction_loss_only = prediction_loss_only,
269
+ per_device_train_batch_size = per_device_train_batch_size,
270
+ per_device_eval_batch_size = per_device_eval_batch_size,
271
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
272
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
273
+ gradient_accumulation_steps = gradient_accumulation_steps,
274
+ eval_accumulation_steps = eval_accumulation_steps,
275
+ eval_delay = eval_delay,
276
+ torch_empty_cache_steps = torch_empty_cache_steps,
277
+ learning_rate = learning_rate,
278
+ weight_decay = weight_decay,
279
+ adam_beta1 = adam_beta1,
280
+ adam_beta2 = adam_beta2,
281
+ adam_epsilon = adam_epsilon,
282
+ max_grad_norm = max_grad_norm,
283
+ num_train_epochs = num_train_epochs,
284
+ max_steps = max_steps,
285
+ lr_scheduler_type = lr_scheduler_type,
286
+ warmup_ratio = warmup_ratio,
287
+ warmup_steps = warmup_steps,
288
+ log_level = log_level,
289
+ log_level_replica = log_level_replica,
290
+ log_on_each_node = log_on_each_node,
291
+ logging_dir = logging_dir,
292
+ logging_strategy = logging_strategy,
293
+ logging_first_step = logging_first_step,
294
+ logging_steps = logging_steps,
295
+ logging_nan_inf_filter = logging_nan_inf_filter,
296
+ save_strategy = save_strategy,
297
+ save_steps = save_steps,
298
+ save_total_limit = save_total_limit,
299
+ save_safetensors = save_safetensors,
300
+ save_on_each_node = save_on_each_node,
301
+ save_only_model = save_only_model,
302
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
303
+ no_cuda = no_cuda,
304
+ use_cpu = use_cpu,
305
+ use_mps_device = use_mps_device,
306
+ seed = seed,
307
+ data_seed = data_seed,
308
+ jit_mode_eval = jit_mode_eval,
309
+ use_ipex = use_ipex,
310
+ bf16 = bf16,
311
+ fp16 = fp16,
312
+ fp16_opt_level = fp16_opt_level,
313
+ half_precision_backend = half_precision_backend,
314
+ bf16_full_eval = bf16_full_eval,
315
+ fp16_full_eval = fp16_full_eval,
316
+ tf32 = tf32,
317
+ local_rank = local_rank,
318
+ ddp_backend = ddp_backend,
319
+ tpu_num_cores = tpu_num_cores,
320
+ tpu_metrics_debug = tpu_metrics_debug,
321
+ debug = debug,
322
+ dataloader_drop_last = dataloader_drop_last,
323
+ eval_steps = eval_steps,
324
+ dataloader_num_workers = dataloader_num_workers,
325
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
326
+ past_index = past_index,
327
+ run_name = run_name,
328
+ disable_tqdm = disable_tqdm,
329
+ remove_unused_columns = remove_unused_columns,
330
+ label_names = label_names,
331
+ load_best_model_at_end = load_best_model_at_end,
332
+ metric_for_best_model = metric_for_best_model,
333
+ greater_is_better = greater_is_better,
334
+ ignore_data_skip = ignore_data_skip,
335
+ fsdp = fsdp,
336
+ fsdp_min_num_params = fsdp_min_num_params,
337
+ fsdp_config = fsdp_config,
338
+ tp_size = tp_size,
339
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
340
+ accelerator_config = accelerator_config,
341
+ deepspeed = deepspeed,
342
+ label_smoothing_factor = label_smoothing_factor,
343
+ optim = optim,
344
+ optim_args = optim_args,
345
+ adafactor = adafactor,
346
+ group_by_length = group_by_length,
347
+ length_column_name = length_column_name,
348
+ report_to = report_to,
349
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
350
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
351
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
352
+ dataloader_pin_memory = dataloader_pin_memory,
353
+ dataloader_persistent_workers = dataloader_persistent_workers,
354
+ skip_memory_metrics = skip_memory_metrics,
355
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
356
+ push_to_hub = push_to_hub,
357
+ resume_from_checkpoint = resume_from_checkpoint,
358
+ hub_model_id = hub_model_id,
359
+ hub_strategy = hub_strategy,
360
+ hub_token = hub_token,
361
+ hub_private_repo = hub_private_repo,
362
+ hub_always_push = hub_always_push,
363
+ gradient_checkpointing = gradient_checkpointing,
364
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
365
+ include_inputs_for_metrics = include_inputs_for_metrics,
366
+ eval_do_concat_batches = eval_do_concat_batches,
367
+ fp16_backend = fp16_backend,
368
+ push_to_hub_model_id = push_to_hub_model_id,
369
+ push_to_hub_organization = push_to_hub_organization,
370
+ push_to_hub_token = push_to_hub_token,
371
+ mp_parameters = mp_parameters,
372
+ auto_find_batch_size = auto_find_batch_size,
373
+ full_determinism = full_determinism,
374
+ torchdynamo = torchdynamo,
375
+ ray_scope = ray_scope,
376
+ ddp_timeout = ddp_timeout,
377
+ torch_compile = torch_compile,
378
+ torch_compile_backend = torch_compile_backend,
379
+ torch_compile_mode = torch_compile_mode,
380
+ include_tokens_per_second = include_tokens_per_second,
381
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
382
+ neftune_noise_alpha = neftune_noise_alpha,
383
+ optim_target_modules = optim_target_modules,
384
+ batch_eval_metrics = batch_eval_metrics,
385
+ eval_on_start = eval_on_start,
386
+ use_liger_kernel = use_liger_kernel,
387
+ eval_use_gather_object = eval_use_gather_object,
388
+ average_tokens_across_devices = average_tokens_across_devices,
389
+ dataset_num_proc = dataset_num_proc,
390
+ num_mini_batches = num_mini_batches,
391
+ total_episodes = total_episodes,
392
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
393
+ num_sample_generations = num_sample_generations,
394
+ response_length = response_length,
395
+ stop_token = stop_token,
396
+ stop_token_id = stop_token_id,
397
+ temperature = temperature,
398
+ missing_eos_penalty = missing_eos_penalty,
399
+ sft_model_path = sft_model_path,
400
+ world_size = world_size,
401
+ num_total_batches = num_total_batches,
402
+ micro_batch_size = micro_batch_size,
403
+ local_batch_size = local_batch_size,
404
+ batch_size = batch_size,
405
+ local_mini_batch_size = local_mini_batch_size,
406
+ mini_batch_size = mini_batch_size,
407
+ exp_name = exp_name,
408
+ reward_model_path = reward_model_path,
409
+ num_ppo_epochs = num_ppo_epochs,
410
+ whiten_rewards = whiten_rewards,
411
+ kl_coef = kl_coef,
412
+ cliprange = cliprange,
413
+ rloo_k = rloo_k,
414
+ normalize_reward = normalize_reward,
415
+ reward_clip_range = reward_clip_range,
416
+ normalize_advantage = normalize_advantage,
417
+ token_level_kl = token_level_kl,
418
+ ds3_gather_for_generation = ds3_gather_for_generation, **kwargs)
419
+ self.vllm_sampling_params = vllm_sampling_params
420
+ self.unsloth_num_chunks = unsloth_num_chunks
421
+ pass
422
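+ 
+ # Usage sketch (values are illustrative assumptions): besides the standard
+ # TrainingArguments fields, the RLOO-specific knobs documented above can be set directly:
+ # config = UnslothRLOOConfig(
+ #     output_dir = "rloo_out",
+ #     per_device_train_batch_size = 4,
+ #     rloo_k = 2,              # online samples per prompt
+ #     kl_coef = 0.05,
+ #     token_level_kl = False,  # sequence-level KL penalty
+ # )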
+
423
+ class _UnslothRLOOTrainer(Trainer):
424
+ _tag_names = ["trl", "rloo"]
425
+
426
+ def __init__(
427
+ self,
428
+ config: RLOOConfig,
429
+ processing_class: Optional[
430
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
431
+ ],
432
+ policy: nn.Module,
433
+ ref_policy: nn.Module,
434
+ reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
435
+ train_dataset: Dataset,
436
+ data_collator: Optional[DataCollatorWithPadding] = None,
437
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
438
+ # less commonly used
439
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
440
+ callbacks: Optional[list[TrainerCallback]] = None,
441
+ ) -> None:
442
+ if ref_policy is policy:
443
+ raise ValueError(
444
+ "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
445
+ "same as `policy`, you must mass a copy of it, or `None` if you use peft."
446
+ )
447
+
448
+ self.args = config
449
+ args = config
450
+ self.processing_class = processing_class
451
+ self.policy = policy
452
+
453
+ # Define the collator if not provided
454
+ if data_collator is None:
455
+ data_collator = DataCollatorWithPadding(self.processing_class)
456
+
457
+ self.policy.generation_config.eos_token_id = (
458
+ None # disable `pad_token_id` and `eos_token_id` because we just want to
459
+ )
460
+ self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
461
+
462
+ self.ref_policy = ref_policy
463
+ self.reward_model = reward_model
464
+ self.train_dataset = train_dataset
465
+ self.train_dataset_len = len(train_dataset)
466
+ self.data_collator = data_collator
467
+ self.eval_dataset = eval_dataset
468
+ self.optimizer, self.lr_scheduler = optimizers
469
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
470
+
471
+ #########
472
+ # calculate various batch sizes
473
+ #########
474
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
475
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
476
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
477
+ self.accelerator = accelerator
478
+ args.world_size = accelerator.num_processes
479
+ args.local_batch_size = (
480
+ args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
481
+ )
482
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
483
+ args.batch_size = int(args.local_batch_size * args.world_size)
484
+ args.mini_batch_size = exact_div(
485
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
486
+ )
487
+ args.local_mini_batch_size = exact_div(
488
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
489
+ )
490
+ args.num_total_batches = math.ceil(
491
+ args.total_episodes / args.batch_size
492
+ ) # we may train for more than `total_episodes`
493
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
494
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
495
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
496
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
497
+ if args.num_sample_generations > 0:
498
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
499
+ self.local_dataloader_batch_size = exact_div(
500
+ args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
501
+ ) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
502
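+ # Worked example: with per_device_train_batch_size = 4, gradient_accumulation_steps = 2,
+ # num_mini_batches = 1 and 2 processes: local_batch_size = 4 * 2 * 1 = 8,
+ # micro_batch_size = 4 * 2 = 8, batch_size = 8 * 2 = 16; with rloo_k = 2 the
+ # dataloader yields 8 / 2 = 4 unique prompts per local batch, each repeated rloo_k times.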
+
503
+ #########
504
+ # setup model, optimizer, and others
505
+ #########
506
+ for module in [policy, ref_policy, reward_model]:
507
+ if isinstance(module, nn.Module):
508
+ disable_dropout_in_model(module)
509
+ if args.stop_token and args.stop_token == "eos":
510
+ args.stop_token_id = self.processing_class.eos_token_id
511
+ self.model = policy
512
+ self.create_optimizer_and_scheduler(
513
+ num_training_steps=args.num_total_batches
514
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
515
+
516
+ #########
517
+ ### trainer specifics
518
+ #########
519
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
520
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
521
+ self.callback_handler = CallbackHandler(
522
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
523
+ )
524
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
525
+ self.control = TrainerControl()
526
+ self.state = OnlineTrainerState(
527
+ is_local_process_zero=self.is_local_process_zero(),
528
+ is_world_process_zero=self.is_world_process_zero(),
529
+ stateful_callbacks=[
530
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
531
+ ],
532
+ )
533
+
534
+ self.current_flos = 0
535
+ self.hp_search_backend = None
536
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
537
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
538
+ # Create distant repo and output directory if needed
539
+ self.hub_model_id = None
540
+ if self.args.push_to_hub:
541
+ self.init_hf_repo()
542
+ if self.args.should_save:
543
+ os.makedirs(self.args.output_dir, exist_ok=True)
544
+ self.backup_model = None
545
+
546
+ # Add tags for models that have been loaded with the correct transformers version
547
+ if hasattr(self.model, "add_model_tags"):
548
+ self.model.add_model_tags(self._tag_names)
549
+
550
+ #########
551
+ ### setup dataloader
552
+ #########
553
+ self.dataloader = DataLoader(
554
+ self.train_dataset,
555
+ batch_size=self.local_dataloader_batch_size,
556
+ shuffle=True,
557
+ collate_fn=self.data_collator,
558
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
559
+ )
560
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
561
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
562
+ torch.manual_seed(args.seed)
563
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
564
+ torch.manual_seed(self.local_seed) # reset the local seed again
565
+
566
+ self.eval_dataloader = DataLoader(
567
+ self.eval_dataset,
568
+ batch_size=args.per_device_eval_batch_size,
569
+ collate_fn=self.data_collator,
570
+ drop_last=True,
571
+ ) # no need to shuffle eval dataset
572
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
573
+
574
+ if self.is_deepspeed_enabled:
575
+ if isinstance(self.reward_model, nn.Module):
576
+ self.reward_model = prepare_deepspeed(
577
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
578
+ )
579
+ self.ref_policy = prepare_deepspeed(
580
+ self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
581
+ )
582
+ self.deepspeed = self.model
583
+ else:
584
+ self.ref_policy = self.ref_policy.to(self.accelerator.device)
585
+ if isinstance(self.reward_model, nn.Module):
586
+ self.reward_model = self.reward_model.to(self.accelerator.device)
587
+
588
+ def get_train_dataloader(self) -> DataLoader:
589
+ return self.dataloader
590
+
591
+ def get_eval_dataloader(self) -> DataLoader:
592
+ return self.eval_dataloader
593
+
594
+ def train(self):
595
+ args = self.args
596
+ accelerator = self.accelerator
597
+ optimizer = self.optimizer
598
+ model = self.model
599
+ self.model_wrapped = self.model
600
+ ref_policy = self.ref_policy
601
+ reward_model = self.reward_model
602
+ processing_class = self.processing_class
603
+ dataloader = self.dataloader
604
+ device = accelerator.device
605
+
606
+ def repeat_generator():
607
+ while True:
608
+ yield from dataloader
609
+
610
+ iter_dataloader = iter(repeat_generator())
611
+ generation_config = GenerationConfig(
612
+ max_new_tokens=args.response_length,
613
+ temperature=(args.temperature + 1e-7),
614
+ top_k=0.0,
615
+ top_p=1.0,
616
+ do_sample=True,
617
+ )
618
+
619
+ accelerator.print("===training policy===")
620
+ start_time = time.time()
621
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
622
+ approxkl_stats = torch.zeros(stats_shape, device=device)
623
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
624
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
625
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
626
+ entropy_stats = torch.zeros(stats_shape, device=device)
627
+ ratio_stats = torch.zeros(stats_shape, device=device)
628
+ model.train()
629
+
630
+ # trainer state initialization
631
+ self.state.global_step = 0
632
+ self.state.episode = 0
633
+ self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
634
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
635
+ # Compute absolute values for logging, eval, and save if given as ratio
636
+ if args.logging_steps is not None:
637
+ if args.logging_steps < 1:
638
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
639
+ else:
640
+ self.state.logging_steps = args.logging_steps
641
+ if args.eval_steps is not None:
642
+ if args.eval_steps < 1:
643
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
644
+ else:
645
+ self.state.eval_steps = args.eval_steps
646
+ if args.save_steps is not None:
647
+ if args.save_steps < 1:
648
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
649
+ else:
650
+ self.state.save_steps = args.save_steps
651
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
652
+
653
+ for update in range(1, args.num_total_batches + 1):
654
+ self.state.episode += 1 * args.batch_size
655
+ data = next(iter_dataloader)
656
+ with torch.no_grad():
657
+ queries = data["input_ids"].to(device)
658
+ queries = queries.repeat(args.rloo_k, 1)
659
+ context_length = queries.shape[1]
660
+ responses = []
661
+ postprocessed_responses = []
662
+ logprobs = []
663
+ ref_logprobs = []
664
+ scores = []
665
+ sequence_lengths = []
666
+
667
+ # Generate responses and compute logprobs
668
+ with unwrap_model_for_generation(
669
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
670
+ ) as unwrapped_model:
671
+ query_responses, logitss = batch_generation(
672
+ unwrapped_model,
673
+ queries,
674
+ args.local_rollout_forward_batch_size,
675
+ processing_class.pad_token_id,
676
+ generation_config,
677
+ )
678
+
679
+ # Process responses in batches
680
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
681
+ query = queries[i : i + args.local_rollout_forward_batch_size]
682
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
683
+ response = query_response[:, context_length:]
684
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
685
+ logprob = selective_log_softmax(logits, response)
686
+ del logits
687
+ torch.cuda.empty_cache()
688
+
689
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
690
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
691
+ ref_logits /= args.temperature + 1e-7
692
+ ref_logprob = selective_log_softmax(ref_logits, response)
693
+ del ref_output, ref_logits
694
+ torch.cuda.empty_cache()
695
+
696
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
697
+ postprocessed_response = response
698
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
699
+ postprocessed_response = truncate_response(
700
+ args.stop_token_id, processing_class.pad_token_id, response
701
+ )
702
+
703
+ # Response Processing 2. run reward model on the truncated responses
704
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
705
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
706
+
707
+ if isinstance(reward_model, nn.Module):
708
+ _, score, _ = get_reward(
709
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
710
+ )
711
+ else:
712
+ score = torch.tensor(
713
+ reward_model(
714
+ processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
715
+ ),
716
+ dtype=torch.float,
717
+ ).to(device)
718
+
719
+ # Store batch results
720
+ responses.append(response)
721
+ postprocessed_responses.append(postprocessed_response)
722
+ logprobs.append(logprob)
723
+ ref_logprobs.append(ref_logprob)
724
+ sequence_lengths.append(sequence_length)
725
+ scores.append(score)
726
+
727
+ # Concatenate all batched results
728
+ responses = torch.cat(responses, 0)
729
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
730
+ logprobs = torch.cat(logprobs, 0)
731
+ ref_logprobs = torch.cat(ref_logprobs, 0)
732
+ sequence_lengths = torch.cat(sequence_lengths, 0)
733
+ scores = torch.cat(scores, 0)
734
+ del (logprob, ref_logprob, score)
735
+ torch.cuda.empty_cache()
736
+ gc.collect()
737
+
738
+ # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
739
+ # responses not passing that filter will receive a low (fixed) score
740
+ # only query humans on responses that pass that filter
741
+ contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
742
+ if args.missing_eos_penalty is not None:
743
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
744
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
745
+
746
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
747
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
748
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
749
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
750
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
751
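+ # e.g. a 3-token response inside a 5-token tensor has sequence_length = 2, so
+ # padding_mask = [False, False, False, True, True]; masked positions hold
+ # INVALID_LOGPROB in both tensors and cancel in `kl = logprobs - ref_logprobs`.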
+
752
+ # 4. compute rewards
753
+ # Compute KL divergence
754
+ kl = logprobs - ref_logprobs
755
+
756
+ # Normalize rewards
757
+ if args.normalize_reward:
758
+ scores = (scores - scores.mean()) / (scores.std() + 1e-8)
759
+ scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
760
+
761
+ # Compute total reward with KL penalty
762
+ if args.token_level_kl:
763
+ # Token-level KL penalty: apply KL penalty per token
764
+ kl_reward = -args.kl_coef * kl
765
+
766
+ # Get the index of the last non-padded token for each sequence
767
+ eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
768
+ last_reward = torch.zeros_like(kl)
769
+ # Ensure scores has correct shape and type
770
+ scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
771
+ last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
772
+
773
+ # Combine KL reward and last reward
774
+ non_score_reward = kl_reward.sum(1) # Keep this for logging
775
+ reward = last_reward + kl_reward
776
+ rlhf_reward = reward.sum(1) # Sum across sequence length
777
+ else:
778
+ # Sequence-level KL penalty: sum KL across tokens first
779
+ sequence_kl = kl.sum(1)
780
+ non_score_reward = -args.kl_coef * sequence_kl
781
+ rlhf_reward = non_score_reward + scores
782
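+ # Numeric sketch: with kl = [0.1, 0.2] for one sequence and kl_coef = 0.05,
+ # the sequence-level penalty is -0.05 * 0.3 = -0.015 added to that sequence's score.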
+
783
+ # vectorized RLOO advantages implementation
784
+ rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
785
+ baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
786
+ advantages = rlhf_reward - baseline
787
+ advantages = advantages.flatten()
788
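+ # Worked example with rloo_k = 2 and rlhf_reward reshaped to [[1.0, 3.0], [2.0, 5.0]]
+ # (k x prompts): each sample's baseline is the mean of the other k - 1 samples for
+ # the same prompt, so baseline = [[2.0, 5.0], [1.0, 3.0]] and
+ # advantages = [[-1.0, -2.0], [1.0, 2.0]].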
+
789
+ # Normalize advantages
790
+ if args.normalize_advantage:
791
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
792
+
793
+ torch.cuda.empty_cache()
794
+
795
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
796
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
797
+ b_inds = np.random.permutation(args.local_batch_size)
798
+ minibatch_idx = 0
799
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
800
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
801
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
802
+ gradient_accumulation_idx = 0
803
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
804
+ with accelerator.accumulate(model):
805
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
806
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
807
+
808
+ # Get batch data
809
+ mb_advantage = advantages[micro_batch_inds]
810
+ mb_responses = responses[micro_batch_inds]
811
+ mb_query_responses = query_responses[micro_batch_inds]
812
+ mb_logprobs = logprobs[micro_batch_inds]
813
+
814
+ # Forward pass
815
+ output = forward(model, mb_query_responses, processing_class.pad_token_id)
816
+ logits = output.logits[:, context_length - 1 : -1]
817
+ logits /= args.temperature + 1e-7
818
+
819
+ # Compute new logprobs
820
+ new_logprobs = selective_log_softmax(logits, mb_responses)
821
+ new_logprobs = torch.masked_fill(
822
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
823
+ )
824
+
825
+ # Compute probability ratios
826
+ new_ratio = (new_logprobs - mb_logprobs).exp()
827
+ new_logprobs = new_logprobs.sum(1)
828
+ mb_logprobs = mb_logprobs.sum(1)
829
+ logprobs_diff = new_logprobs - mb_logprobs
830
+ ratio = torch.exp(logprobs_diff)
831
+
832
+ # PPO clipped loss
833
+ pg_losses = -mb_advantage * ratio
834
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
835
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
836
+ pg_loss = pg_loss_max.mean()
837
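+ # Clipping illustration: with cliprange = 0.2, ratio = 1.5 and advantage = +1,
+ # pg_losses = -1.5 and pg_losses2 = -1.2; max(-1.5, -1.2) = -1.2, so the clipped
+ # (constant) term wins and the sample stops pushing the ratio further from 1.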
+
838
+ # Final loss
839
+ loss = pg_loss
840
+
841
+ # Optimization step
842
+ accelerator.backward(loss)
843
+ optimizer.step()
844
+ optimizer.zero_grad()
845
+
846
+ with torch.no_grad():
847
+ pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
848
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
849
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
850
+ approxkl = 0.5 * (logprobs_diff**2).mean()
851
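+ # entropy uses H(p) = logsumexp(x) - sum_i p_i * x_i, which follows from
+ # log p_i = x_i - logsumexp(x); approxkl is the k2 estimator 0.5 * E[(log r)^2]
+ # (Schulman's KL approximation note), with r the new/old probability ratio.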
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
852
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
853
+ pg_clipfrac
854
+ )
855
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
856
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
857
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
858
+ gradient_accumulation_idx += 1
859
+ minibatch_idx += 1
860
+
861
+ # del everything and empty cache
862
+ # fmt: off
863
+ del (
864
+ output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
865
+ pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
866
+ mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
867
+ )
868
+ # fmt: on
869
+ torch.cuda.empty_cache()
870
+
871
+ # Compute metrics
872
+ with torch.no_grad():
873
+ mean_kl = kl.sum(1).mean()
874
+ mean_entropy = (-logprobs).sum(1).mean()
875
+ mean_non_score_reward = non_score_reward.mean()
876
+ eps = int(self.state.episode / (time.time() - start_time))
877
+ metrics = {}
878
+ metrics["eps"] = eps
879
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
880
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
881
+ metrics["objective/non_score_reward"] = (
882
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
883
+ )
884
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
885
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
886
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
887
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
888
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
889
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
890
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
891
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
892
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
893
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
894
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
895
+ metrics["episode"] = self.state.episode
896
+ self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log
897
+ self.log(metrics)
898
+ del kl, mean_kl, mean_entropy, scores
899
+
900
+ self.lr_scheduler.step()
901
+ self.state.global_step += 1
902
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
903
+ if self.control.should_save:
904
+ self._save_checkpoint(model, trial=None)
905
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
906
+ torch.cuda.empty_cache()
907
+ gc.collect()
908
+
909
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
910
+ self.generate_completions(sampling=True)
911
+
912
+ # HF trainer specifics
913
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
914
+ if self.control.should_save:
915
+ self._save_checkpoint(model, trial=None, metrics=None)
916
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
917
+
918
+ def generate_completions(self, sampling: bool = False):
919
+ args = self.args
920
+ processing_class = self.processing_class
921
+ generation_config = GenerationConfig(
922
+ max_new_tokens=self.args.response_length,
923
+ temperature=(0.01 + 1e-7),
924
+ top_k=0.0,
925
+ top_p=1.0,
926
+ do_sample=True,
927
+ )
928
+
929
+ table = defaultdict(list)
930
+ with unwrap_model_for_generation(
931
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
932
+ ) as unwrapped_model:
933
+ for batch in self.eval_dataloader:
934
+ query = batch["input_ids"]
935
+            with torch.no_grad():
+                context_length = query.shape[1]
+                query_response, _ = batch_generation(
+                    unwrapped_model,
+                    query,
+                    query.shape[0],
+                    processing_class.pad_token_id,
+                    generation_config,
+                )
+                response = query_response[:, context_length:]
+                postprocessed_response = response
+                if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
+                    postprocessed_response = truncate_response(
+                        args.stop_token_id, processing_class.pad_token_id, response
+                    )
+                table["query"].extend(
+                    gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
+                )
+                table["model response"].extend(
+                    gather_object(processing_class.batch_decode(postprocessed_response))
+                )
+
+                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+
+                if isinstance(self.reward_model, nn.Module):
+                    _, score, _ = get_reward(
+                        self.reward_model,
+                        postprocessed_query_response,
+                        processing_class.pad_token_id,
+                        context_length,
+                    )
+                else:
+                    score = torch.tensor(
+                        self.reward_model(
+                            processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
+                        ),
+                        dtype=torch.float,
+                    ).to(postprocessed_query_response.device)
+                table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
+
+            if sampling:
+                break
+        df = pd.DataFrame(table)
+
+        if self.accelerator.is_main_process:
+            print_rich_table(df.iloc[0 : 0 + 5])
+            if "wandb" in args.report_to:
+                import wandb
+
+                if wandb.run is not None:
+                    wandb.log({"completions": wandb.Table(dataframe=df)})
+
+            if "comet_ml" in args.report_to:
+                log_table_to_comet_experiment(
+                    name="completions.csv",
+                    table=df,
+                )
+
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        citation = textwrap.dedent("""\
+        @inproceedings{ahmadian2024back,
+            title     = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
+            author    = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
+            year      = 2024,
+            booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
+            publisher = {Association for Computational Linguistics},
+            pages     = {12248--12267},
+            editor    = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
+        }""")
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="RLOO",
+            trainer_citation=citation,
+            paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
+            paper_id="2402.14740",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+class UnslothRLOOTrainer(_UnslothRLOOTrainer):
+    """
+
+    """
+    def __init__(
+        self,
+        config,
+        processing_class,
+        policy,
+        ref_policy,
+        reward_model,
+        train_dataset,
+        data_collator = None,
+        eval_dataset = None,
+        callbacks = None,
+        **kwargs
+    ):
+        if config is None: config = UnslothRLOOConfig()
+        args = config   # the RLOO trainer receives its config as `config`; alias it for the shared setup below
+        model = policy  # the setup hooks below operate on the policy model
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training()
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(__tokenizer)
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+                else:
+                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('rloo_trainer', other_metrics)
+
+        super().__init__(
+            config = config,
+            processing_class = processing_class,
+            policy = policy,
+            ref_policy = ref_policy,
+            reward_model = reward_model,
+            train_dataset = train_dataset,
+            data_collator = data_collator,
+            eval_dataset = eval_dataset,
+            callbacks = callbacks, **kwargs)
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
+        if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+        if getattr(args, 'neftune_noise_alpha', None) is not None:
+            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+        pass
+
+pass
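Reviewer note: since `generate_completions` above branches on whether `reward_model` is an `nn.Module`, a plain callable that scores decoded text also works. A minimal wiring sketch, assuming `tok`, `policy`, `ref_policy`, and `train_dataset` are already built (all placeholder names, not part of this diff):

    def text_reward(texts):
        # hypothetical reward: one float per decoded completion
        return [float(len(t)) for t in texts]

    trainer = UnslothRLOOTrainer(
        config = UnslothRLOOConfig(output_dir = "rloo_out"),  # placeholder path
        processing_class = tok,
        policy = policy,
        ref_policy = ref_policy,
        reward_model = text_reward,  # callable branch: scores batch_decode() output
        train_dataset = train_dataset,
    )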
unsloth_compiled_cache/UnslothRewardTrainer.py ADDED
@@ -0,0 +1,813 @@
+"""
+2025.5.8
+2025.5.7
+4.51.3
+0.15.2
+__UNSLOTH_VERSIONING__
+"""
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings)
+
+
+import os
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+torch_compile_options = {
+    "epilogue_fusion"   : True,
+    "max_autotune"      : False,
+    "shape_padding"     : True,
+    "trace.enabled"     : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def selective_log_softmax(logits, index):
+    logits = logits.to(torch.float32)
+    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+    # loop to reduce peak mem consumption
+    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+    logsumexp_values = torch.logsumexp(logits, dim = -1)
+    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+    return per_token_logps
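Reviewer note: a quick numerical check of the identity in the comment above, log_softmax(x)_i = x_i - logsumexp(x). This standalone sketch re-implements the math with plain eager ops rather than calling the compiled function, so it does not require a working torch.compile backend:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 10)          # (batch, seq, vocab)
    index = torch.randint(0, 10, (2, 5))    # token ids whose log-probs we want
    manual = torch.gather(logits, -1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, dim=-1)
    reference = F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(manual, reference, atol=1e-5)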
+@dataclass
+class UnslothRewardConfig(RewardConfig):
+    """
+
+    Configuration class for the [`RewardTrainer`].
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the sequences (prompt + completion) in the batch; filters out entries that exceed the
+            limit. This argument is required if you want to use the default data collator.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
+        dataset_num_proc (`int`, *optional*, defaults to `None`):
+            Number of processes to use for processing the dataset.
+        center_rewards_coefficient (`float`, *optional*, defaults to `None`):
+            Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
+            https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
+        remove_unused_columns (`bool`, *optional*, defaults to `False`):
+            Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
+            the dataset is pretokenized.
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks : Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    def __init__(
+        self,
+        output_dir = None,
+        overwrite_output_dir = None,
+        do_train = False,
+        do_eval = False,
+        do_predict = False,
+        eval_strategy = 'no',
+        prediction_loss_only = False,
+        per_device_train_batch_size = 4,
+        per_device_eval_batch_size = 4,
+        per_gpu_train_batch_size = None,
+        per_gpu_eval_batch_size = None,
+        gradient_accumulation_steps = 2,
+        eval_accumulation_steps = 2,
+        eval_delay = 0,
+        torch_empty_cache_steps = 250,
+        learning_rate = 5e-05,
+        weight_decay = 0.01,
+        adam_beta1 = 0.9,
+        adam_beta2 = 0.999,
+        adam_epsilon = 1e-08,
+        max_grad_norm = 1.0,
+        num_train_epochs = 3.0,
+        max_steps = -1,
+        lr_scheduler_type = 'linear',
+        warmup_ratio = 0.1,
+        warmup_steps = 0,
+        log_level = 'passive',
+        log_level_replica = 'warning',
+        log_on_each_node = True,
+        logging_dir = None,
+        logging_strategy = 'steps',
+        logging_first_step = False,
+        logging_steps = 1,
+        logging_nan_inf_filter = False,
+        save_strategy = 'steps',
+        save_steps = 500,
+        save_total_limit = None,
+        save_safetensors = True,
+        save_on_each_node = False,
+        save_only_model = False,
+        restore_callback_states_from_checkpoint = False,
+        no_cuda = False,
+        use_cpu = False,
+        use_mps_device = False,
+        seed = 3407,
+        data_seed = 3407,
+        jit_mode_eval = False,
+        use_ipex = False,
+        bf16 = False,
+        fp16 = False,
+        fp16_opt_level = 'O1',
+        half_precision_backend = 'auto',
+        bf16_full_eval = False,
+        fp16_full_eval = False,
+        tf32 = None,
+        local_rank = -1,
+        ddp_backend = None,
+        tpu_num_cores = None,
+        tpu_metrics_debug = False,
+        debug = '',
+        dataloader_drop_last = False,
+        eval_steps = None,
+        dataloader_num_workers = 0,
+        dataloader_prefetch_factor = None,
+        past_index = -1,
+        run_name = None,
+        disable_tqdm = None,
+        remove_unused_columns = False,
+        label_names = None,
+        load_best_model_at_end = False,
+        metric_for_best_model = None,
+        greater_is_better = None,
+        ignore_data_skip = False,
+        fsdp = '',
+        fsdp_min_num_params = 0,
+        fsdp_config = None,
+        tp_size = 0,
+        fsdp_transformer_layer_cls_to_wrap = None,
+        accelerator_config = None,
+        deepspeed = None,
+        label_smoothing_factor = 0.0,
+        optim = 'adamw_8bit',
+        optim_args = None,
+        adafactor = False,
+        group_by_length = False,
+        length_column_name = 'length',
+        report_to = None,
+        ddp_find_unused_parameters = None,
+        ddp_bucket_cap_mb = None,
+        ddp_broadcast_buffers = None,
+        dataloader_pin_memory = True,
+        dataloader_persistent_workers = False,
+        skip_memory_metrics = True,
+        use_legacy_prediction_loop = False,
+        push_to_hub = False,
+        resume_from_checkpoint = None,
+        hub_model_id = None,
+        hub_strategy = 'every_save',
+        hub_token = None,
+        hub_private_repo = None,
+        hub_always_push = False,
+        gradient_checkpointing = False,
+        gradient_checkpointing_kwargs = None,
+        include_inputs_for_metrics = False,
+        eval_do_concat_batches = True,
+        fp16_backend = 'auto',
+        push_to_hub_model_id = None,
+        push_to_hub_organization = None,
+        push_to_hub_token = None,
+        mp_parameters = '',
+        auto_find_batch_size = False,
+        full_determinism = False,
+        torchdynamo = None,
+        ray_scope = 'last',
+        ddp_timeout = 1800,
+        torch_compile = False,
+        torch_compile_backend = None,
+        torch_compile_mode = None,
+        include_tokens_per_second = False,
+        include_num_input_tokens_seen = False,
+        neftune_noise_alpha = None,
+        optim_target_modules = None,
+        batch_eval_metrics = False,
+        eval_on_start = False,
+        use_liger_kernel = False,
+        eval_use_gather_object = False,
+        average_tokens_across_devices = False,
+        max_length = 1024,
+        disable_dropout = True,
+        dataset_num_proc = None,
+        center_rewards_coefficient = None,
+        vllm_sampling_params = None,
+        unsloth_num_chunks = -1,
+        **kwargs,
+    ):
+        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
+        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+            output_dir = 'unsloth_training_checkpoints'
+            save_strategy = 'no'
+        if dataset_num_proc is None:
+            from multiprocessing import cpu_count
+            dataset_num_proc = cpu_count()
+
+        super().__init__(
+            output_dir = output_dir,
+            overwrite_output_dir = overwrite_output_dir,
+            do_train = do_train,
+            do_eval = do_eval,
+            do_predict = do_predict,
+            eval_strategy = eval_strategy,
+            prediction_loss_only = prediction_loss_only,
+            per_device_train_batch_size = per_device_train_batch_size,
+            per_device_eval_batch_size = per_device_eval_batch_size,
+            per_gpu_train_batch_size = per_gpu_train_batch_size,
+            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+            gradient_accumulation_steps = gradient_accumulation_steps,
+            eval_accumulation_steps = eval_accumulation_steps,
+            eval_delay = eval_delay,
+            torch_empty_cache_steps = torch_empty_cache_steps,
+            learning_rate = learning_rate,
+            weight_decay = weight_decay,
+            adam_beta1 = adam_beta1,
+            adam_beta2 = adam_beta2,
+            adam_epsilon = adam_epsilon,
+            max_grad_norm = max_grad_norm,
+            num_train_epochs = num_train_epochs,
+            max_steps = max_steps,
+            lr_scheduler_type = lr_scheduler_type,
+            warmup_ratio = warmup_ratio,
+            warmup_steps = warmup_steps,
+            log_level = log_level,
+            log_level_replica = log_level_replica,
+            log_on_each_node = log_on_each_node,
+            logging_dir = logging_dir,
+            logging_strategy = logging_strategy,
+            logging_first_step = logging_first_step,
+            logging_steps = logging_steps,
+            logging_nan_inf_filter = logging_nan_inf_filter,
+            save_strategy = save_strategy,
+            save_steps = save_steps,
+            save_total_limit = save_total_limit,
+            save_safetensors = save_safetensors,
+            save_on_each_node = save_on_each_node,
+            save_only_model = save_only_model,
+            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+            no_cuda = no_cuda,
+            use_cpu = use_cpu,
+            use_mps_device = use_mps_device,
+            seed = seed,
+            data_seed = data_seed,
+            jit_mode_eval = jit_mode_eval,
+            use_ipex = use_ipex,
+            bf16 = bf16,
+            fp16 = fp16,
+            fp16_opt_level = fp16_opt_level,
+            half_precision_backend = half_precision_backend,
+            bf16_full_eval = bf16_full_eval,
+            fp16_full_eval = fp16_full_eval,
+            tf32 = tf32,
+            local_rank = local_rank,
+            ddp_backend = ddp_backend,
+            tpu_num_cores = tpu_num_cores,
+            tpu_metrics_debug = tpu_metrics_debug,
+            debug = debug,
+            dataloader_drop_last = dataloader_drop_last,
+            eval_steps = eval_steps,
+            dataloader_num_workers = dataloader_num_workers,
+            dataloader_prefetch_factor = dataloader_prefetch_factor,
+            past_index = past_index,
+            run_name = run_name,
+            disable_tqdm = disable_tqdm,
+            remove_unused_columns = remove_unused_columns,
+            label_names = label_names,
+            load_best_model_at_end = load_best_model_at_end,
+            metric_for_best_model = metric_for_best_model,
+            greater_is_better = greater_is_better,
+            ignore_data_skip = ignore_data_skip,
+            fsdp = fsdp,
+            fsdp_min_num_params = fsdp_min_num_params,
+            fsdp_config = fsdp_config,
+            tp_size = tp_size,
+            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+            accelerator_config = accelerator_config,
+            deepspeed = deepspeed,
+            label_smoothing_factor = label_smoothing_factor,
+            optim = optim,
+            optim_args = optim_args,
+            adafactor = adafactor,
+            group_by_length = group_by_length,
+            length_column_name = length_column_name,
+            report_to = report_to,
+            ddp_find_unused_parameters = ddp_find_unused_parameters,
+            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+            ddp_broadcast_buffers = ddp_broadcast_buffers,
+            dataloader_pin_memory = dataloader_pin_memory,
+            dataloader_persistent_workers = dataloader_persistent_workers,
+            skip_memory_metrics = skip_memory_metrics,
+            use_legacy_prediction_loop = use_legacy_prediction_loop,
+            push_to_hub = push_to_hub,
+            resume_from_checkpoint = resume_from_checkpoint,
+            hub_model_id = hub_model_id,
+            hub_strategy = hub_strategy,
+            hub_token = hub_token,
+            hub_private_repo = hub_private_repo,
+            hub_always_push = hub_always_push,
+            gradient_checkpointing = gradient_checkpointing,
+            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+            include_inputs_for_metrics = include_inputs_for_metrics,
+            eval_do_concat_batches = eval_do_concat_batches,
+            fp16_backend = fp16_backend,
+            push_to_hub_model_id = push_to_hub_model_id,
+            push_to_hub_organization = push_to_hub_organization,
+            push_to_hub_token = push_to_hub_token,
+            mp_parameters = mp_parameters,
+            auto_find_batch_size = auto_find_batch_size,
+            full_determinism = full_determinism,
+            torchdynamo = torchdynamo,
+            ray_scope = ray_scope,
+            ddp_timeout = ddp_timeout,
+            torch_compile = torch_compile,
+            torch_compile_backend = torch_compile_backend,
+            torch_compile_mode = torch_compile_mode,
+            include_tokens_per_second = include_tokens_per_second,
+            include_num_input_tokens_seen = include_num_input_tokens_seen,
+            neftune_noise_alpha = neftune_noise_alpha,
+            optim_target_modules = optim_target_modules,
+            batch_eval_metrics = batch_eval_metrics,
+            eval_on_start = eval_on_start,
+            use_liger_kernel = use_liger_kernel,
+            eval_use_gather_object = eval_use_gather_object,
+            average_tokens_across_devices = average_tokens_across_devices,
+            max_length = max_length,
+            disable_dropout = disable_dropout,
+            dataset_num_proc = dataset_num_proc,
+            center_rewards_coefficient = center_rewards_coefficient, **kwargs)
+        self.vllm_sampling_params = vllm_sampling_params
+        self.unsloth_num_chunks = unsloth_num_chunks
+    pass
+
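Reviewer note: a minimal instantiation sketch for the config above; `center_rewards_coefficient = 0.01` is the value the docstring recommends for the mean-zero reward regularizer, and `output_dir` is a placeholder:

    reward_args = UnslothRewardConfig(
        output_dir = "reward_out",          # placeholder path
        max_length = 1024,                  # filters out longer chosen/rejected pairs
        center_rewards_coefficient = 0.01,  # auxiliary loss from the docstring (Eq. 2 of 2312.09244)
    )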
+class _UnslothRewardTrainer(Trainer):
+    _tag_names = ["trl", "reward-trainer"]
+
+    def __init__(
+        self,
+        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
+        args: Optional[RewardConfig] = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
+            None,
+            None,
+        ),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        peft_config: Optional[dict] = None,
+    ):
+        """
+        Initialize RewardTrainer.
+
+        Args:
+            model (`transformers.PreTrainedModel`):
+                The model to train, preferably an `AutoModelForSequenceClassification`.
+            args (`RewardConfig`):
+                The arguments to use for training.
+            data_collator (`transformers.DataCollator`):
+                The data collator to use for training. If None is specified, the default data collator
+                (`RewardDataCollatorWithPadding`) will be used, which pads the sequences to the maximum length of
+                the sequences in the batch, given a dataset of paired sequences.
+            train_dataset (`datasets.Dataset`):
+                The dataset to use for training.
+            eval_dataset (`datasets.Dataset`):
+                The dataset to use for evaluation.
+            processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+                Processing class used to process the data. If provided, will be used to automatically process the inputs
+                for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+                reuse the fine-tuned model.
+            model_init (`Callable[[], transformers.PreTrainedModel]`):
+                The model initializer to use for training. If None is specified, the default model initializer will be used.
+            compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
+                The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
+            callbacks (`list[transformers.TrainerCallback]`):
+                The callbacks to use for training.
+            optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+                The optimizer and scheduler to use for training.
+            preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+                The function to use to preprocess the logits before computing the metrics.
+            peft_config (`dict`, defaults to `None`):
+                The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
+        """
+        if not is_peft_available() and peft_config is not None:
+            raise ValueError(
+                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+            )
+        elif is_peft_available() and peft_config is not None:
+            if not isinstance(model, PeftModel):
+                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
+                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+                        inspect.signature(prepare_model_for_kbit_training).parameters
+                    )
+
+                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+                        warnings.warn(
+                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
+                            "Please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
+                            UserWarning,
+                        )
+                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+
+                model = model  # kept as a no-op in this generated cache
+
+        # Disable dropout in the model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+
+        if compute_metrics is None:
+            compute_metrics = compute_accuracy
+
+        max_length = args.max_length  # also needed by the length filters below when a custom collator is passed
+
+        if data_collator is None:
+            if processing_class is None:
+                raise ValueError(
+                    "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
+                )
+
+            data_collator = RewardDataCollatorWithPadding(processing_class)
+
+            if args.remove_unused_columns:
+                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
+                    args.remove_unused_columns = False
+                except FrozenInstanceError:
+                    args = replace(args, remove_unused_columns=False)
+                # warn users
+                warnings.warn(
+                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your"
+                    " RewardConfig. We have set it for you, but you should do it yourself in the future.",
+                    UserWarning,
+                )
+
+            self.use_reward_data_collator = True
+        else:
+            self.use_reward_data_collator = False
+
+        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+        # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
+        # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
+        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
+        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
+        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
+        # issued.
+        model.warnings_issued["estimate_tokens"] = True
+
+        if "input_ids_chosen" not in train_dataset.column_names:
+            with PartialState().local_main_process_first():
+                fn_kwargs = {"tokenizer": processing_class}
+                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
+                train_dataset = train_dataset.map(
+                    _tokenize,
+                    batched=True,
+                    fn_kwargs=fn_kwargs,
+                    num_proc=args.dataset_num_proc,
+                )
+                # This filter is important because otherwise you get samples that exceed the model's context length
+                # and get truncated => noisy signal, since the chosen/rejected label gets lost. The downside is that
+                # the user might get surprised if N samples are missing from training.
+                train_dataset = train_dataset.filter(
+                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
+                    num_proc=args.dataset_num_proc,
+                )
+                if eval_dataset is not None:
+                    eval_dataset = eval_dataset.map(
+                        maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
+                    )
+                    eval_dataset = eval_dataset.map(
+                        _tokenize,
+                        fn_kwargs=fn_kwargs,
+                        batched=True,
+                        num_proc=args.dataset_num_proc,
+                    )
+                    # Same filter as above: drop pairs that would be truncated past the model's context length.
+                    eval_dataset = eval_dataset.filter(
+                        lambda x: len(x["input_ids_chosen"]) <= max_length
+                        and len(x["input_ids_rejected"]) <= max_length,
+                        num_proc=args.dataset_num_proc,
+                    )
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            model_init=model_init,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        # Add tags for models that have been loaded with the correct transformers version
+        if hasattr(self.model, "add_model_tags"):
+            self.model.add_model_tags(self._tag_names)
+
+    def compute_loss(
+        self,
+        model: Union[PreTrainedModel, nn.Module],
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs=False,
+        num_items_in_batch=None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
+        rewards_chosen = model(
+            input_ids=inputs["input_ids_chosen"],
+            attention_mask=inputs["attention_mask_chosen"],
+            return_dict=True,
+        )["logits"]
+        rewards_rejected = model(
+            input_ids=inputs["input_ids_rejected"],
+            attention_mask=inputs["attention_mask_rejected"],
+            return_dict=True,
+        )["logits"]
+        # calculate loss, optionally modulate with margin
+        if "margin" in inputs:
+            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
+        else:
+            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
+
+        if self.args.center_rewards_coefficient is not None:
+            loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
+
+        if return_outputs:
+            return loss, {
+                "rewards_chosen": rewards_chosen,
+                "rewards_rejected": rewards_rejected,
+            }
+        return loss
+
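Reviewer note: a worked micro-example of the pairwise loss in `compute_loss` above, loss = -log sigmoid(r_chosen - r_rejected - margin) with the margin term omitted. With rewards 1.0 and -1.0 the loss is -log sigmoid(2.0) ≈ 0.1269:

    import torch
    import torch.nn.functional as F

    rewards_chosen = torch.tensor([1.0])
    rewards_rejected = torch.tensor([-1.0])
    loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
    print(round(loss.item(), 4))  # 0.1269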
+    def prediction_step(
+        self,
+        model: Union[PreTrainedModel, nn.Module],
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[list[str]] = None,
+    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        inputs = self._prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
+
+        with torch.no_grad():
+            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        loss = loss.detach()
+        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
+        logits = nested_detach(logits)
+        # Stack accepted against rejected, mean over logits
+        # and softmax to get preferences between accepted and rejected to sum to 1
+        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
+
+        labels = torch.zeros(logits.shape[0])
+        labels = self._prepare_inputs(labels)
+
+        return loss, logits, labels
+
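Reviewer note: a shape walk-through of the stack/mean/softmax step in `prediction_step` above, using dummy tensors. Two (batch, 1) reward tensors become one (batch, 2) matrix whose rows sum to 1:

    import torch

    rewards_chosen = torch.randn(4, 1)                        # (batch, 1)
    rewards_rejected = torch.randn(4, 1)
    logits = torch.stack((rewards_chosen, rewards_rejected))  # (2, batch, 1)
    logits = logits.mean(dim=2).softmax(dim=0).T              # (batch, 2), rows sum to 1
    assert torch.allclose(logits.sum(dim=1), torch.ones(4))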
+    def evaluate(self, *args, **kwargs):
+        num_print_samples = kwargs.pop("num_print_samples", 4)
+        self.visualize_samples(num_print_samples)
+        return super().evaluate(*args, **kwargs)
+
+    def visualize_samples(self, num_print_samples: int):
+        """
+        Visualize the reward model logits prediction
+
+        Args:
+            num_print_samples (`int`, defaults to `4`):
+                The number of samples to print. Set to `-1` to print all samples.
+        """
+        eval_dataloader = self.get_eval_dataloader()
+        table = defaultdict(list)
+        for _, inputs in enumerate(eval_dataloader):
+            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
+            chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
+            rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
+            table["chosen_text"].extend(gather_object(chosen_text))
+            table["rejected_text"].extend(gather_object(rejected_text))
+            table["logits"].extend(
+                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
+            )
+            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
+                break
+        df = pd.DataFrame(table)
+        if self.accelerator.process_index == 0:
+            print_rich_table(df[:num_print_samples])
+            if "wandb" in self.args.report_to:
+                import wandb
+
+                if wandb.run is not None:
+                    wandb.log({"completions": wandb.Table(dataframe=df)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="completions.csv",
+                    table=df,
+                )
+
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="Reward",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+class UnslothRewardTrainer(_UnslothRewardTrainer):
+    """
+
+    """
+    def __init__(
+        self,
+        model = None,
+        args = None,
+        data_collator = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        model_init = None,
+        compute_metrics = None,
+        callbacks = None,
+        preprocess_logits_for_metrics = None,
+        peft_config = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothRewardConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        use_fp16 = getattr(args, 'fp16', False)
+        force_float32 = False
+        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training()
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(__tokenizer)
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+                else:
+                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('reward_trainer', other_metrics)
+
+        super().__init__(
+            model = model,
+            args = args,
+            data_collator = data_collator,
+            train_dataset = train_dataset,
+            eval_dataset = eval_dataset,
+            processing_class = processing_class,
+            model_init = model_init,
+            compute_metrics = compute_metrics,
+            callbacks = callbacks,
+            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+            peft_config = peft_config, **kwargs)
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
+        if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+        if getattr(args, 'neftune_noise_alpha', None) is not None:
+            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+        pass
+
+pass
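Reviewer note: a minimal wiring sketch for the class above (model name, dataset, and variable names are placeholders). When the dataset is not pretokenized, `__init__` applies the chat template and tokenizes the `chosen`/`rejected` pairs itself, as the code above shows:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # placeholder model
    rm = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels = 1,                  # scalar reward head
    )
    trainer = UnslothRewardTrainer(
        model = rm,
        args = UnslothRewardConfig(output_dir = "reward_out"),
        processing_class = tok,
        train_dataset = preference_dataset,  # hypothetical dataset with `chosen`/`rejected` columns
    )
    trainer.train()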
unsloth_compiled_cache/UnslothSFTTrainer.py ADDED
@@ -0,0 +1,1025 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, wandb, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothSFTConfig(SFTConfig):
44
+ """
45
+
46
+ Configuration class for the [`SFTTrainer`].
47
+
48
+ Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
49
+ [`~transformers.TrainingArguments`] documentation.
50
+
51
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
52
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
53
+ command line.
54
+
55
+ Parameters:
56
+ > Parameters that control the model
57
+
58
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
59
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
60
+ argument of the [`SFTTrainer`] is provided as a string.
61
+ use_liger (`bool`, *optional*, defaults to `False`):
62
+ Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
63
+
64
+ > Parameters that control the data preprocessing
65
+
66
+ dataset_text_field (`str`, *optional*, defaults to `"text"`):
67
+ Name of the column that contains text data in the dataset.
68
+ dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
69
+ Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
70
+ `skip_prepare_dataset`.
71
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
72
+ Number of processes to use for processing the dataset.
73
+ max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
74
+ Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
75
+ right.
76
+ If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
77
+ packing (`bool`, *optional*, defaults to `False`):
78
+ Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
79
+ length.
80
+ eval_packing (`bool` or `None`, *optional*, defaults to `None`):
81
+ Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
82
+
83
+ > Parameters that control the training
84
+
85
+ learning_rate (`float`, *optional*, defaults to `2e-5`):
86
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
87
+ [`~transformers.TrainingArguments`].
88
+
89
+ """
90
+ vllm_sampling_params: Optional[Any] = field(
91
+ default = None,
92
+ metadata = {'help': 'vLLM SamplingParams'},
93
+ )
94
+ unsloth_num_chunks : Optional[int] = field(
95
+ default = -1,
96
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
97
+ )
98
+ def __init__(
99
+ self,
100
+ output_dir = None,
101
+ overwrite_output_dir = None,
102
+ do_train = False,
103
+ do_eval = False,
104
+ do_predict = False,
105
+ eval_strategy = 'no',
106
+ prediction_loss_only = False,
107
+ per_device_train_batch_size = 4,
108
+ per_device_eval_batch_size = 4,
109
+ per_gpu_train_batch_size = None,
110
+ per_gpu_eval_batch_size = None,
111
+ gradient_accumulation_steps = 2,
112
+ eval_accumulation_steps = 2,
113
+ eval_delay = 0,
114
+ torch_empty_cache_steps = 250,
115
+ learning_rate = 5e-05,
116
+ weight_decay = 0.01,
117
+ adam_beta1 = 0.9,
118
+ adam_beta2 = 0.999,
119
+ adam_epsilon = 1e-08,
120
+ max_grad_norm = 1.0,
121
+ num_train_epochs = 3.0,
122
+ max_steps = -1,
123
+ lr_scheduler_type = 'linear',
124
+ warmup_ratio = 0.1,
125
+ warmup_steps = 0,
126
+ log_level = 'passive',
127
+ log_level_replica = 'warning',
128
+ log_on_each_node = True,
129
+ logging_dir = None,
130
+ logging_strategy = 'steps',
131
+ logging_first_step = False,
132
+ logging_steps = 1,
133
+ logging_nan_inf_filter = False,
134
+ save_strategy = 'steps',
135
+ save_steps = 500,
136
+ save_total_limit = None,
137
+ save_safetensors = True,
138
+ save_on_each_node = False,
139
+ save_only_model = False,
140
+ restore_callback_states_from_checkpoint = False,
141
+ no_cuda = False,
142
+ use_cpu = False,
143
+ use_mps_device = False,
144
+ seed = 3407,
145
+ data_seed = 3407,
146
+ jit_mode_eval = False,
147
+ use_ipex = False,
148
+ bf16 = False,
149
+ fp16 = False,
150
+ fp16_opt_level = 'O1',
151
+ half_precision_backend = 'auto',
152
+ bf16_full_eval = False,
153
+ fp16_full_eval = False,
154
+ tf32 = None,
155
+ local_rank = -1,
156
+ ddp_backend = None,
157
+ tpu_num_cores = None,
158
+ tpu_metrics_debug = False,
159
+ debug = '',
160
+ dataloader_drop_last = False,
161
+ eval_steps = None,
162
+ dataloader_num_workers = 0,
163
+ dataloader_prefetch_factor = None,
164
+ past_index = -1,
165
+ run_name = None,
166
+ disable_tqdm = None,
167
+ remove_unused_columns = True,
168
+ label_names = None,
169
+ load_best_model_at_end = False,
170
+ metric_for_best_model = None,
171
+ greater_is_better = None,
172
+ ignore_data_skip = False,
173
+ fsdp = '',
174
+ fsdp_min_num_params = 0,
175
+ fsdp_config = None,
176
+ tp_size = 0,
177
+ fsdp_transformer_layer_cls_to_wrap = None,
178
+ accelerator_config = None,
179
+ deepspeed = None,
180
+ label_smoothing_factor = 0.0,
181
+ optim = 'adamw_8bit',
182
+ optim_args = None,
183
+ adafactor = False,
184
+ group_by_length = False,
185
+ length_column_name = 'length',
186
+ report_to = None,
187
+ ddp_find_unused_parameters = None,
188
+ ddp_bucket_cap_mb = None,
189
+ ddp_broadcast_buffers = None,
190
+ dataloader_pin_memory = True,
191
+ dataloader_persistent_workers = False,
192
+ skip_memory_metrics = True,
193
+ use_legacy_prediction_loop = False,
194
+ push_to_hub = False,
195
+ resume_from_checkpoint = None,
196
+ hub_model_id = None,
197
+ hub_strategy = 'every_save',
198
+ hub_token = None,
199
+ hub_private_repo = None,
200
+ hub_always_push = False,
201
+ gradient_checkpointing = False,
202
+ gradient_checkpointing_kwargs = None,
203
+ include_inputs_for_metrics = False,
204
+ eval_do_concat_batches = True,
205
+ fp16_backend = 'auto',
206
+ push_to_hub_model_id = None,
207
+ push_to_hub_organization = None,
208
+ push_to_hub_token = None,
209
+ mp_parameters = '',
210
+ auto_find_batch_size = False,
211
+ full_determinism = False,
212
+ torchdynamo = None,
213
+ ray_scope = 'last',
214
+ ddp_timeout = 1800,
215
+ torch_compile = False,
216
+ torch_compile_backend = None,
217
+ torch_compile_mode = None,
218
+ include_tokens_per_second = False,
219
+ include_num_input_tokens_seen = False,
220
+ neftune_noise_alpha = None,
221
+ optim_target_modules = None,
222
+ batch_eval_metrics = False,
223
+ eval_on_start = False,
224
+ use_liger_kernel = False,
225
+ eval_use_gather_object = False,
226
+ average_tokens_across_devices = False,
227
+ model_init_kwargs = None,
228
+ use_liger = False,
229
+ dataset_text_field = 'text',
230
+ dataset_kwargs = None,
231
+ dataset_num_proc = None,
232
+ max_seq_length = None,
233
+ packing = False,
234
+ eval_packing = None,
235
+ dataset_batch_size = None,
236
+ num_of_sequences = None,
237
+ chars_per_token = None,
238
+ vllm_sampling_params = None,
239
+ unsloth_num_chunks = -1,
240
+ **kwargs,
241
+ ):
242
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
243
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
244
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
245
+ output_dir = 'unsloth_training_checkpoints'
246
+ save_strategy = 'no'
247
+ if dataset_num_proc is None:
248
+ from multiprocessing import cpu_count
249
+ dataset_num_proc = cpu_count()
250
+
251
+ super().__init__(
252
+ output_dir = output_dir,
253
+ overwrite_output_dir = overwrite_output_dir,
254
+ do_train = do_train,
255
+ do_eval = do_eval,
256
+ do_predict = do_predict,
257
+ eval_strategy = eval_strategy,
258
+ prediction_loss_only = prediction_loss_only,
259
+ per_device_train_batch_size = per_device_train_batch_size,
260
+ per_device_eval_batch_size = per_device_eval_batch_size,
261
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
262
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
263
+ gradient_accumulation_steps = gradient_accumulation_steps,
264
+ eval_accumulation_steps = eval_accumulation_steps,
265
+ eval_delay = eval_delay,
266
+ torch_empty_cache_steps = torch_empty_cache_steps,
267
+ learning_rate = learning_rate,
268
+ weight_decay = weight_decay,
269
+ adam_beta1 = adam_beta1,
270
+ adam_beta2 = adam_beta2,
271
+ adam_epsilon = adam_epsilon,
272
+ max_grad_norm = max_grad_norm,
273
+ num_train_epochs = num_train_epochs,
274
+ max_steps = max_steps,
275
+ lr_scheduler_type = lr_scheduler_type,
276
+ warmup_ratio = warmup_ratio,
277
+ warmup_steps = warmup_steps,
278
+ log_level = log_level,
279
+ log_level_replica = log_level_replica,
280
+ log_on_each_node = log_on_each_node,
281
+ logging_dir = logging_dir,
282
+ logging_strategy = logging_strategy,
283
+ logging_first_step = logging_first_step,
284
+ logging_steps = logging_steps,
285
+ logging_nan_inf_filter = logging_nan_inf_filter,
286
+ save_strategy = save_strategy,
287
+ save_steps = save_steps,
288
+ save_total_limit = save_total_limit,
289
+ save_safetensors = save_safetensors,
290
+ save_on_each_node = save_on_each_node,
291
+ save_only_model = save_only_model,
292
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
293
+ no_cuda = no_cuda,
294
+ use_cpu = use_cpu,
295
+ use_mps_device = use_mps_device,
296
+ seed = seed,
297
+ data_seed = data_seed,
298
+ jit_mode_eval = jit_mode_eval,
299
+ use_ipex = use_ipex,
300
+ bf16 = bf16,
301
+ fp16 = fp16,
302
+ fp16_opt_level = fp16_opt_level,
303
+ half_precision_backend = half_precision_backend,
304
+ bf16_full_eval = bf16_full_eval,
305
+ fp16_full_eval = fp16_full_eval,
306
+ tf32 = tf32,
307
+ local_rank = local_rank,
308
+ ddp_backend = ddp_backend,
309
+ tpu_num_cores = tpu_num_cores,
310
+ tpu_metrics_debug = tpu_metrics_debug,
311
+ debug = debug,
312
+ dataloader_drop_last = dataloader_drop_last,
313
+ eval_steps = eval_steps,
314
+ dataloader_num_workers = dataloader_num_workers,
315
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
316
+ past_index = past_index,
317
+ run_name = run_name,
318
+ disable_tqdm = disable_tqdm,
319
+ remove_unused_columns = remove_unused_columns,
320
+ label_names = label_names,
321
+ load_best_model_at_end = load_best_model_at_end,
322
+ metric_for_best_model = metric_for_best_model,
323
+ greater_is_better = greater_is_better,
324
+ ignore_data_skip = ignore_data_skip,
325
+ fsdp = fsdp,
326
+ fsdp_min_num_params = fsdp_min_num_params,
327
+ fsdp_config = fsdp_config,
328
+ tp_size = tp_size,
329
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
330
+ accelerator_config = accelerator_config,
331
+ deepspeed = deepspeed,
332
+ label_smoothing_factor = label_smoothing_factor,
333
+ optim = optim,
334
+ optim_args = optim_args,
335
+ adafactor = adafactor,
336
+ group_by_length = group_by_length,
337
+ length_column_name = length_column_name,
338
+ report_to = report_to,
339
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
340
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
341
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
342
+ dataloader_pin_memory = dataloader_pin_memory,
343
+ dataloader_persistent_workers = dataloader_persistent_workers,
344
+ skip_memory_metrics = skip_memory_metrics,
345
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
346
+ push_to_hub = push_to_hub,
347
+ resume_from_checkpoint = resume_from_checkpoint,
348
+ hub_model_id = hub_model_id,
349
+ hub_strategy = hub_strategy,
350
+ hub_token = hub_token,
351
+ hub_private_repo = hub_private_repo,
352
+ hub_always_push = hub_always_push,
353
+ gradient_checkpointing = gradient_checkpointing,
354
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
355
+ include_inputs_for_metrics = include_inputs_for_metrics,
356
+ eval_do_concat_batches = eval_do_concat_batches,
357
+ fp16_backend = fp16_backend,
358
+ push_to_hub_model_id = push_to_hub_model_id,
359
+ push_to_hub_organization = push_to_hub_organization,
360
+ push_to_hub_token = push_to_hub_token,
361
+ mp_parameters = mp_parameters,
362
+ auto_find_batch_size = auto_find_batch_size,
363
+ full_determinism = full_determinism,
364
+ torchdynamo = torchdynamo,
365
+ ray_scope = ray_scope,
366
+ ddp_timeout = ddp_timeout,
367
+ torch_compile = torch_compile,
368
+ torch_compile_backend = torch_compile_backend,
369
+ torch_compile_mode = torch_compile_mode,
370
+ include_tokens_per_second = include_tokens_per_second,
371
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
372
+ neftune_noise_alpha = neftune_noise_alpha,
373
+ optim_target_modules = optim_target_modules,
374
+ batch_eval_metrics = batch_eval_metrics,
375
+ eval_on_start = eval_on_start,
376
+ use_liger_kernel = use_liger_kernel,
377
+ eval_use_gather_object = eval_use_gather_object,
378
+ average_tokens_across_devices = average_tokens_across_devices,
379
+ model_init_kwargs = model_init_kwargs,
380
+ use_liger = use_liger,
381
+ dataset_text_field = dataset_text_field,
382
+ dataset_kwargs = dataset_kwargs,
383
+ dataset_num_proc = dataset_num_proc,
384
+ max_seq_length = max_seq_length,
385
+ packing = packing,
386
+ eval_packing = eval_packing,
387
+ dataset_batch_size = dataset_batch_size,
388
+ num_of_sequences = num_of_sequences,
389
+ chars_per_token = chars_per_token,**kwargs)
390
+ self.vllm_sampling_params = vllm_sampling_params
391
+ self.unsloth_num_chunks = unsloth_num_chunks
392
+ pass
393
+
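For context, a minimal usage sketch of the config class above (the field names come from the signature; the values are illustrative). The guards reject learning rates outside [1e-7, 1], and when no output_dir is given under the default save settings, checkpoints are redirected to unsloth_training_checkpoints with saving disabled.

config = UnslothSFTConfig(
    output_dir = "outputs",
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,
    learning_rate = 5e-5,      # must satisfy 1e-7 <= lr <= 1 per the guards above
    max_seq_length = 2048,
    packing = False,
)
# UnslothSFTConfig(learning_rate = 2.0)  # would raise OverflowError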
394
+ class _UnslothSFTTrainer(Trainer):
395
+ """"""
396
+
397
+ _tag_names = ["trl", "sft"]
398
+
399
+ @deprecate_kwarg(
400
+ "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
401
+ )
402
+ def __init__(
403
+ self,
404
+ model: Union[str, nn.Module, PreTrainedModel],
405
+ args: Optional[Union[SFTConfig, TrainingArguments]] = None,
406
+ data_collator: Optional[DataCollator] = None, # type: ignore
407
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
408
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
409
+ processing_class: Optional[
410
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
411
+ ] = None,
412
+ compute_loss_func: Optional[Callable] = None,
413
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
414
+ callbacks: Optional[list[TrainerCallback]] = None,
415
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
416
+ optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
417
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
418
+ peft_config: Optional["PeftConfig"] = None,
419
+ formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
420
+ ):
421
+ # Args
422
+ if args is None:
423
+ model_name = model if isinstance(model, str) else model.config._name_or_path
424
+ model_name = model_name.split("/")[-1]
425
+ args = SFTConfig(f"{model_name}-SFT")
426
+ elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
427
+ dict_args = args.to_dict()
428
+ dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
429
+ dict_args.pop("push_to_hub_token")
430
+ args = SFTConfig(**dict_args)
431
+
432
+ # Model
433
+ if args.model_init_kwargs is not None and not isinstance(model, str):
434
+ warnings.warn(
435
+ "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
436
+ "The `model_init_kwargs` will be ignored."
437
+ )
438
+ if isinstance(model, str):
439
+ model = self._create_model_from_path(model, args)
440
+
441
+ # PEFT configuration and model wrapping
442
+ if False:
443
+ model = self._prepare_peft_model(model, peft_config, args)
444
+
445
+ # Handle the tokenizer
446
+ if processing_class is None:
447
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
448
+ if processing_class.pad_token is None:
449
+ processing_class.pad_token = processing_class.eos_token # required for padding when collating data
450
+
451
+ # Dataset
452
+ preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
453
+ if preprocess_dataset:
454
+ train_dataset = self._prepare_dataset(
455
+ train_dataset, processing_class, args, args.packing, formatting_func, "train"
456
+ )
457
+ if eval_dataset is not None:
458
+ packing = args.packing if args.eval_packing is None else args.eval_packing
459
+ if isinstance(eval_dataset, dict):
460
+ eval_dataset = {
461
+ key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
462
+ for key, dataset in eval_dataset.items()
463
+ }
464
+ else:
465
+ eval_dataset = self._prepare_dataset(
466
+ eval_dataset, processing_class, args, packing, formatting_func, "eval"
467
+ )
468
+
469
+ # Data collator
470
+ if data_collator is None:
471
+ data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
472
+
473
+ # Initialize the metrics
474
+ self._metrics = defaultdict(list)
475
+
476
+ # Initialize the Trainer. Parent class will handle:
477
+ # - DeepSpeed configuration (through create_accelerator_and_postprocess)
478
+ # - FSDP setup
479
+ # - Distributed training setup
480
+ # - Optimizer and scheduler creation
481
+ # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
482
+ super_init_kwargs = {}
483
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
484
+ super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
485
+ else:
486
+ if optimizer_cls_and_kwargs is not None:
487
+ warnings.warn(
488
+ "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
489
+ "The default optimizer will be used. "
490
+ "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
491
+ )
492
+ super().__init__(
493
+ model=model,
494
+ args=args,
495
+ data_collator=data_collator,
496
+ train_dataset=train_dataset,
497
+ eval_dataset=eval_dataset,
498
+ processing_class=processing_class,
499
+ compute_loss_func=compute_loss_func,
500
+ compute_metrics=compute_metrics,
501
+ callbacks=callbacks,
502
+ optimizers=optimizers,
503
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
504
+ **super_init_kwargs,
505
+ )
506
+
507
+ # Add tags for models that have been loaded with the correct transformers version
508
+ if hasattr(self.model, "add_model_tags"):
509
+ self.model.add_model_tags(self._tag_names)
510
+
511
+ def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
512
+ """Creates a model from a path or model identifier."""
513
+ model_init_kwargs = args.model_init_kwargs or {}
514
+ # Handle torch dtype
515
+ torch_dtype = model_init_kwargs.get("torch_dtype")
516
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
517
+ pass # torch_dtype is already a torch.dtype or "auto" or None
518
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
519
+ torch_dtype = getattr(torch, torch_dtype)
520
+ model_init_kwargs["torch_dtype"] = torch_dtype
521
+ else:
522
+ raise ValueError(
523
+ "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
524
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
525
+ )
526
+ # Disable caching if gradient checkpointing is enabled (not supported)
527
+ if args.gradient_checkpointing:
528
+ model_init_kwargs["use_cache"] = False
529
+
530
+ # Create model
531
+ if args.use_liger:
532
+ if not is_liger_kernel_available():
533
+ raise ImportError("Please install Liger-kernel for use_liger=True")
534
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
535
+ else:
536
+ model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
537
+ return model
538
+
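The torch_dtype normalization in _create_model_from_path reduces to a small helper; a sketch of the same logic in isolation (the helper name is ours, not part of this file):

import torch

def resolve_torch_dtype(value):
    # torch.dtype instances, "auto", and None pass through unchanged
    if isinstance(value, torch.dtype) or value == "auto" or value is None:
        return value
    if isinstance(value, str):  # e.g. "bfloat16" -> torch.bfloat16
        return getattr(torch, value)
    raise ValueError(f"Invalid torch_dtype: {value}")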
539
+ def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
540
+ """Prepares a model for PEFT training."""
541
+ if not is_peft_available():
542
+ raise ImportError("To use PeftModel, you need to install the `peft` library.")
543
+
544
+ if not isinstance(peft_config, PeftConfig):
545
+ raise ValueError(
546
+ f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
547
+ "to pass a PeftConfig object to the SFTTrainer."
548
+ )
549
+
550
+ if isinstance(model, PeftModel):
551
+ return model
552
+
553
+ # Handle quantized models (QLoRA)
554
+ is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
555
+
556
+ is_sharded_qlora = False
557
+ if getattr(model, "is_loaded_in_4bit", False):
558
+ # Check if model is sharded (FSDP/DS-Zero3)
559
+ for _, param in model.named_parameters():
560
+ if param.__class__.__name__ == "Params4bit":
561
+ is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
562
+ break
563
+
564
+ # Prepare model for kbit training if needed
565
+ if is_qlora and not is_sharded_qlora:
566
+ model = self._prepare_model_for_kbit_training(model, args)
567
+ # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
568
+ args = dataclasses.replace(args, gradient_checkpointing=False)
569
+ elif args.gradient_checkpointing:
570
+ model = self._enable_gradient_checkpointing(model, args)
571
+
572
+ # Create PEFT model
573
+ if (
574
+ version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
575
+ and getattr(model, "is_loaded_in_4bit", False)
576
+ and is_sharded_qlora
577
+ ):
578
+ model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
579
+ else:
580
+ model = get_peft_model(model, peft_config)
581
+
582
+ # Handle bf16 casting for 4-bit models
583
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
584
+ peft_module_casting_to_bf16(model)
585
+
586
+ return model
587
+
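The sharded-QLoRA probe above can be read as a standalone predicate (the function name is ours): under FSDP or DeepSpeed ZeRO-3, bitsandbytes Params4bit weights sit on cpu/meta devices before materialization, which is what the loop keys on.

def looks_like_sharded_qlora(model) -> bool:
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            return param.data.device.type in {"cpu", "meta"}
    return False  # no 4-bit parameters found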
588
+ def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
589
+ """Prepares a quantized model for kbit training."""
590
+ prepare_model_kwargs = {
591
+ "use_gradient_checkpointing": args.gradient_checkpointing,
592
+ "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
593
+ }
594
+
595
+ return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
596
+
597
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
598
+ """Enables gradient checkpointing for the model."""
599
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
600
+ use_reentrant = (
601
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
602
+ )
603
+
604
+ if use_reentrant:
605
+ if hasattr(model, "enable_input_require_grads"):
606
+ model.enable_input_require_grads()
607
+ else:
608
+
609
+ def make_inputs_require_grad(module, input, output):
610
+ output.requires_grad_(True)
611
+
612
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
613
+
614
+ return model
615
+
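Why the hook above exists: reentrant gradient checkpointing needs at least one input tensor with requires_grad=True, and the embedding output lacks it when the embeddings are frozen. The fallback path registers a forward hook on the embedding layer, roughly:

def make_inputs_require_grad(module, inputs, output):
    output.requires_grad_(True)  # ensures checkpointed segments receive gradients

handle = model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# handle.remove() detaches the hook when it is no longer needed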
616
+ def _prepare_dataset(
617
+ self,
618
+ dataset: Union[Dataset, IterableDataset],
619
+ processing_class,
620
+ args,
621
+ packing: bool,
622
+ formatting_func: Optional[Callable[[dict], str]],
623
+ dataset_name: str,
624
+ ) -> Union[Dataset, IterableDataset]:
625
+ # All Unsloth Zoo code licensed under LGPLv3
626
+ if isinstance(dataset, ConstantLengthDataset): return dataset
627
+
628
+ map_kwargs = {}
629
+ use_desc = isinstance(dataset, Dataset)
630
+ is_vlm = hasattr(processing_class, "tokenizer")
631
+ tokenizer = processing_class
632
+ if is_vlm: tokenizer = processing_class.tokenizer
633
+
634
+ # Get max length
635
+ max_seq_length = getattr(args, "max_length", 0)
636
+ if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
637
+ if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
638
+ if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
639
+ if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
640
+ dataset_text_field = getattr(args, "dataset_text_field", "text")
641
+ do_truncation = max_seq_length != 0
642
+ do_formatting_func = False
643
+ do_tokenize = True
644
+
645
+ # Get correct column names
646
+ column_names = set(next(iter(dataset)).keys())
647
+ used_column_names = ["input_ids"]
648
+ if "attention_mask" in column_names:
649
+ used_column_names.append("attention_mask")
650
+
651
+ # Check if the dataset is already tokenized; if so, skip preparation
652
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
653
+ if "labels" in column_names:
654
+ # Most likely forgot data collator!
655
+ if is_vlm and not hasattr(tokenizer, "pad"):
656
+ # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
657
+ raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
658
+ self.data_collator = DataCollatorForSeq2Seq(tokenizer)
659
+ used_column_names.append("labels")
660
+ do_tokenize = False
661
+ elif "input_ids" in column_names:
662
+ # Skip dataset prep, and set data collator
663
+ if is_vlm and not hasattr(tokenizer, "pad"):
664
+ # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
665
+ raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
666
+ self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
667
+ do_tokenize = False
668
+ elif dataset_text_field not in column_names:
669
+ do_formatting_func = True
670
+ if formatting_func is None:
671
+ raise RuntimeError("Unsloth: You must specify a `formatting_func`")
672
+ pass
673
+
674
+ if do_tokenize:
675
+ # Check double BOS tokens
676
+ if do_formatting_func:
677
+ test_text = formatting_func(next(iter(dataset)))
678
+ if not isinstance(test_text, list):
679
+ raise ValueError(
680
+ "Unsloth: The `formatting_func` should return a list of processed strings."
681
+ )
682
+ test_text = test_text[0]
683
+ else:
684
+ test_text = next(iter(dataset))[dataset_text_field][0]
685
+
686
+ # Get chat template
687
+ chat_template = getattr(processing_class, 'chat_template', '')
688
+ if chat_template == '' and is_vlm:
689
+ chat_template = getattr(tokenizer, 'chat_template', '')
690
+ if chat_template is None:
691
+ chat_template = ''
692
+
693
+ # Get bos_token
694
+ add_special_tokens = True
695
+ bos_token_1 = getattr(processing_class, 'bos_token', None)
696
+ bos_token_2 = getattr(tokenizer, 'bos_token', None)
697
+ bos_token = bos_token_1 or bos_token_2
698
+
699
+ if bos_token is not None:
700
+ if test_text.startswith(bos_token) or bos_token in chat_template:
701
+ add_special_tokens = False
702
+ print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
703
+ pass
704
+
705
+ # Create tokenize function
706
+ def _tokenize(example):
707
+ return tokenizer(
708
+ example[dataset_text_field] if not do_formatting_func else formatting_func(example),
709
+ truncation = do_truncation,
710
+ max_length = max_seq_length,
711
+ return_token_type_ids = False,
712
+ add_special_tokens = add_special_tokens,
713
+ )
714
+ pass
715
+
716
+ if not isinstance(dataset, IterableDataset):
717
+ map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
718
+ else:
719
+ map_kwargs["batch_size"] = dataset._ex_iterable.batch_size
720
+
721
+ if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
722
+ dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
723
+
724
+ # If VLM, switch data collator since .pad is needed!
725
+ if is_vlm and not hasattr(processing_class, "pad"):
726
+ data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
727
+ self.data_collator = data_collator
728
+ pass
729
+ pass
730
+ if packing:
731
+ print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
732
+ return dataset
733
+
734
+ if max_seq_length == 0:
735
+ raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
736
+
737
+ if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
738
+ dataset = dataset.select_columns(used_column_names).map(
739
+ pack_examples,
740
+ batched = True,
741
+ fn_kwargs = {"seq_length": max_seq_length,},
742
+ **map_kwargs,
743
+ )
744
+ pass
745
+ return dataset
746
+
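pack_examples is mapped over the dataset above; its assumed behavior (this sketch is ours, not the library implementation) is to concatenate the tokenized columns and re-split them into fixed seq_length chunks so batches carry no padding:

def pack_examples_sketch(batch, seq_length):
    packed = {}
    for key, rows in batch.items():  # e.g. "input_ids", "attention_mask"
        flat = [token for row in rows for token in row]
        packed[key] = [flat[i : i + seq_length] for i in range(0, len(flat), seq_length)]
    return packed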
747
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
748
+ outputs = super().compute_loss(
749
+ model,
750
+ inputs,
751
+ return_outputs = return_outputs,
752
+ num_items_in_batch = num_items_in_batch,
753
+ )
754
+ return outputs
755
+
756
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
757
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
758
+
759
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
760
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
761
+ if next(iter(logs.keys())).startswith("eval_"):
762
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
763
+
764
+ logs = {**logs, **metrics}
765
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
766
+ super().log(logs, start_time)
767
+ else: # transformers<=4.46
768
+ super().log(logs)
769
+ self._metrics.clear()
770
+
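The logging scheme above accumulates raw values in self._metrics between steps and reports their means; a toy run with hypothetical numbers:

from collections import defaultdict

_metrics = defaultdict(list)
_metrics["mean_token_accuracy"] += [0.71, 0.74, 0.78]  # hypothetical values
averaged = {key: sum(val) / len(val) for key, val in _metrics.items()}
# averaged == {"mean_token_accuracy": 0.7433333333333333}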
771
+ def create_model_card(
772
+ self,
773
+ model_name: Optional[str] = None,
774
+ dataset_name: Optional[str] = None,
775
+ tags: Union[str, list[str], None] = None,
776
+ ):
777
+ """
778
+ Creates a draft of a model card using the information available to the `Trainer`.
779
+
780
+ Args:
781
+ model_name (`str` or `None`, *optional*, defaults to `None`):
782
+ Name of the model.
783
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
784
+ Name of the dataset used for training.
785
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
786
+ Tags to be associated with the model card.
787
+ """
788
+ if not self.is_world_process_zero():
789
+ return
790
+
791
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
792
+ base_model = self.model.config._name_or_path
793
+ else:
794
+ base_model = None
795
+
796
+ tags = tags or []
797
+ if isinstance(tags, str):
798
+ tags = [tags]
799
+
800
+ if hasattr(self.model.config, "unsloth_version"):
801
+ tags.append("unsloth")
802
+
803
+ model_card = generate_model_card(
804
+ base_model=base_model,
805
+ model_name=model_name,
806
+ hub_model_id=self.hub_model_id,
807
+ dataset_name=dataset_name,
808
+ tags=tags,
809
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
810
+ comet_url=get_comet_experiment_url(),
811
+ trainer_name="SFT",
812
+ )
813
+
814
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
815
+ class UnslothSFTTrainer(_UnslothSFTTrainer):
816
+ """
817
+
818
+ Trainer for the Supervised Fine-Tuning (SFT) method.
819
+
820
+ This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
821
+
822
+ Example:
823
+
824
+ ```python
825
+ from datasets import load_dataset
826
+ from trl import SFTTrainer
827
+
828
+ dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
829
+
830
+ trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
831
+ trainer.train()
832
+ ```
833
+
834
+ Args:
835
+ model (`Union[str, PreTrainedModel]`):
836
+ Model to be trained. Can be either:
837
+
838
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
839
+ a path to a *directory* containing model weights saved using
840
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
841
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
842
+ in `args.model_init_kwargs`.
843
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
844
+ args ([`SFTConfig`], *optional*, defaults to `None`):
845
+ Configuration for this trainer. If `None`, a default configuration is used.
846
+ data_collator (`DataCollator`, *optional*):
847
+ Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
849
+ Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, or to an instance
850
+ of [`~transformers.DataCollatorWithPadding`] if the `processing_class` is a feature extractor or
tokenizer.
851
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
852
+ Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
853
+ [prompt-completion](#prompt-completion) type. The format of the samples can be either:
854
+
855
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
856
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
857
+ and content).
858
+
859
+ The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
860
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
861
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
862
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
863
+ Processing class used to process the data. If `None`, the processing class is loaded from the model's name
864
+ with [`~transformers.AutoTokenizer.from_pretrained`].
865
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
866
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
867
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
868
+
869
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
870
+ method.
871
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
872
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
873
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
874
+ optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
875
+ A tuple containing the optimizer class and keyword arguments to use.
876
+ Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
877
+
878
+ Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
879
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
880
+ A function that preprocesses the logits right before caching them at each evaluation step. Must take two
881
+ tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
882
+ by this function will be reflected in the predictions received by `compute_metrics`.
883
+
884
+ Note that the labels (second parameter) will be `None` if the dataset does not have them.
885
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
886
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
887
+ formatting_func (`Optional[Callable]`):
888
+ Formatting function applied to the dataset before tokenization.
889
+
890
+ """
891
+ def __init__(
892
+ self,
893
+ model,
894
+ args = None,
895
+ data_collator = None,
896
+ train_dataset = None,
897
+ eval_dataset = None,
898
+ processing_class = None,
899
+ compute_loss_func = None,
900
+ compute_metrics = None,
901
+ callbacks = None,
902
+ optimizer_cls_and_kwargs = None,
903
+ preprocess_logits_for_metrics = None,
904
+ peft_config = None,
905
+ formatting_func = None,
906
+ **kwargs
907
+ ):
908
+ if args is None: args = UnslothSFTConfig()
909
+ use_bf16 = getattr(args, 'bf16', False)
910
+ use_fp16 = getattr(args, 'fp16', False)
911
+ force_float32 = False
912
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
913
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
914
+ force_float32 = True
915
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
916
+ dtype = getattr(model.config, 'torch_dtype', None)
917
+ if dtype is None: dtype = model.get_input_embeddings().dtype
918
+ from unsloth_zoo.utils import _get_dtype
919
+ dtype = _get_dtype(dtype)
920
+ float16 = dtype == torch.float16
921
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
922
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
923
+ if force_float32:
924
+ args.fp16 = False
925
+ args.bf16 = False
926
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
927
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
928
+ args.fp16 = float16
929
+ args.bf16 = not float16
930
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
931
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
932
+ args.eval_strategy = 'steps'
933
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
934
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
935
+ if ga_steps is not None and ga_steps > 1:
936
+ from transformers import __version__ as transformers_version
937
+ if Version(transformers_version) <= Version('4.45.2'):
938
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
939
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
940
+ if getattr(args, 'eval_strategy', 'no') != 'no':
941
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
942
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
943
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
944
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
945
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
946
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
947
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
948
+ if force_float32:
949
+ args.bf16_full_eval = False
950
+ args.fp16_full_eval = False
951
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
952
+ args.bf16_full_eval = True
953
+ args.fp16_full_eval = False
954
+ elif not bf16_full_eval and not fp16_full_eval:
955
+ args.bf16_full_eval = args.bf16
956
+ args.fp16_full_eval = args.fp16
957
+ _output_logits = False
958
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
959
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
960
+ if _output_logits:
961
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
962
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
963
+ pass
964
+ else:
965
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
966
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
967
+ if args_max_seq_length is None and model_max_seq_length is not None:
968
+ max_seq_length = model.max_seq_length
969
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
970
+ if model is not None and hasattr(model, 'for_training'):
971
+ model.for_training()
972
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
973
+ if 'processing_class' in locals():
974
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
975
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
976
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
977
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
978
+ if not isinstance(data_collator, UnslothVisionDataCollator):
979
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
980
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
981
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
982
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
983
+ else:
984
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
985
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
986
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
987
+ if not isinstance(data_collator, UnslothVisionDataCollator):
988
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
989
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
990
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
991
+ else:
992
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
993
+ other_metrics = []
994
+
995
+ from unsloth_zoo.logging_utils import PatchRLStatistics
996
+ PatchRLStatistics('sft_trainer', other_metrics)
997
+ IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
998
+ from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
999
+ from unsloth_zoo.training_utils import fix_zero_training_loss
1000
+ if 'tokenizer' not in locals(): tokenizer = processing_class
1001
+ fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
1002
+ fix_zero_training_loss(model, tokenizer, train_dataset)
1003
+
1004
+ super().__init__(
1005
+ model = model,
1006
+ args = args,
1007
+ data_collator = data_collator,
1008
+ train_dataset = train_dataset,
1009
+ eval_dataset = eval_dataset,
1010
+ processing_class = processing_class,
1011
+ compute_loss_func = compute_loss_func,
1012
+ compute_metrics = compute_metrics,
1013
+ callbacks = callbacks,
1014
+ optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
1015
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1016
+ peft_config = peft_config,
1017
+ formatting_func = formatting_func,**kwargs)
1018
+ if hasattr(self, 'neftune_hook_handle'):
1019
+ self.neftune_hook_handle.remove()
1020
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1021
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1022
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1023
+ pass
1024
+
1025
+ pass
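Putting the wrapper together, a minimal end-to-end sketch; model, tokenizer, and dataset are assumed to be loaded already (e.g. via Unsloth's FastLanguageModel), and all values are illustrative:

trainer = UnslothSFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = dataset,  # needs a "text" or "input_ids" column
    args = UnslothSFTConfig(
        output_dir = "outputs",
        max_seq_length = 2048,
        per_device_train_batch_size = 4,
        learning_rate = 5e-5,
    ),
)
trainer.train()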
unsloth_compiled_cache/UnslothXPOTrainer.py ADDED
@@ -0,0 +1,1004 @@
1
+ """
2
+ 2025.5.8
3
+ 2025.5.7
4
+ 4.51.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
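The kernel above relies on the identity log_softmax(x)_i = x_i - logsumexp(x); a quick equivalence check against the naive formulation (shapes are illustrative):

logits = torch.randn(2, 3, 8)        # [batch, seq, vocab]
index = torch.randint(0, 8, (2, 3))  # chosen token ids
naive = torch.log_softmax(logits.float(), dim = -1)
naive = torch.gather(naive, -1, index.unsqueeze(-1)).squeeze(-1)
fast = selective_log_softmax(logits, index)
assert torch.allclose(naive, fast, atol = 1e-5)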
+ @dataclass
43
+ class UnslothXPOConfig(XPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`XPOTrainer`].
47
+
48
+ Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
49
+
50
+ Parameters:
51
+ alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
52
+ Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
53
+ and the last alpha is used for the rest of the epochs.
54
+
55
+ """
56
+ vllm_sampling_params: Optional[Any] = field(
57
+ default = None,
58
+ metadata = {'help': 'vLLM SamplingParams'},
59
+ )
60
+ unsloth_num_chunks : Optional[int] = field(
61
+ default = -1,
62
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
63
+ )
64
+ def __init__(
65
+ self,
66
+ output_dir = None,
67
+ overwrite_output_dir = None,
68
+ do_train = False,
69
+ do_eval = False,
70
+ do_predict = False,
71
+ eval_strategy = 'no',
72
+ prediction_loss_only = False,
73
+ per_device_train_batch_size = 4,
74
+ per_device_eval_batch_size = 4,
75
+ per_gpu_train_batch_size = None,
76
+ per_gpu_eval_batch_size = None,
77
+ gradient_accumulation_steps = 2,
78
+ eval_accumulation_steps = 2,
79
+ eval_delay = 0,
80
+ torch_empty_cache_steps = 250,
81
+ learning_rate = 5e-05,
82
+ weight_decay = 0.01,
83
+ adam_beta1 = 0.9,
84
+ adam_beta2 = 0.999,
85
+ adam_epsilon = 1e-08,
86
+ max_grad_norm = 1.0,
87
+ num_train_epochs = 3.0,
88
+ max_steps = -1,
89
+ lr_scheduler_type = 'linear',
90
+ warmup_ratio = 0.1,
91
+ warmup_steps = 0,
92
+ log_level = 'passive',
93
+ log_level_replica = 'warning',
94
+ log_on_each_node = True,
95
+ logging_dir = None,
96
+ logging_strategy = 'steps',
97
+ logging_first_step = False,
98
+ logging_steps = 1,
99
+ logging_nan_inf_filter = False,
100
+ save_strategy = 'steps',
101
+ save_steps = 500,
102
+ save_total_limit = None,
103
+ save_safetensors = True,
104
+ save_on_each_node = False,
105
+ save_only_model = False,
106
+ restore_callback_states_from_checkpoint = False,
107
+ no_cuda = False,
108
+ use_cpu = False,
109
+ use_mps_device = False,
110
+ seed = 3407,
111
+ data_seed = 3407,
112
+ jit_mode_eval = False,
113
+ use_ipex = False,
114
+ bf16 = False,
115
+ fp16 = False,
116
+ fp16_opt_level = 'O1',
117
+ half_precision_backend = 'auto',
118
+ bf16_full_eval = False,
119
+ fp16_full_eval = False,
120
+ tf32 = None,
121
+ local_rank = -1,
122
+ ddp_backend = None,
123
+ tpu_num_cores = None,
124
+ tpu_metrics_debug = False,
125
+ debug = '',
126
+ dataloader_drop_last = False,
127
+ eval_steps = None,
128
+ dataloader_num_workers = 0,
129
+ dataloader_prefetch_factor = None,
130
+ past_index = -1,
131
+ run_name = None,
132
+ disable_tqdm = None,
133
+ remove_unused_columns = True,
134
+ label_names = None,
135
+ load_best_model_at_end = False,
136
+ metric_for_best_model = None,
137
+ greater_is_better = None,
138
+ ignore_data_skip = False,
139
+ fsdp = '',
140
+ fsdp_min_num_params = 0,
141
+ fsdp_config = None,
142
+ tp_size = 0,
143
+ fsdp_transformer_layer_cls_to_wrap = None,
144
+ accelerator_config = None,
145
+ deepspeed = None,
146
+ label_smoothing_factor = 0.0,
147
+ optim = 'adamw_8bit',
148
+ optim_args = None,
149
+ adafactor = False,
150
+ group_by_length = False,
151
+ length_column_name = 'length',
152
+ report_to = None,
153
+ ddp_find_unused_parameters = None,
154
+ ddp_bucket_cap_mb = None,
155
+ ddp_broadcast_buffers = None,
156
+ dataloader_pin_memory = True,
157
+ dataloader_persistent_workers = False,
158
+ skip_memory_metrics = True,
159
+ use_legacy_prediction_loop = False,
160
+ push_to_hub = False,
161
+ resume_from_checkpoint = None,
162
+ hub_model_id = None,
163
+ hub_strategy = 'every_save',
164
+ hub_token = None,
165
+ hub_private_repo = None,
166
+ hub_always_push = False,
167
+ gradient_checkpointing = False,
168
+ gradient_checkpointing_kwargs = None,
169
+ include_inputs_for_metrics = False,
170
+ eval_do_concat_batches = True,
171
+ fp16_backend = 'auto',
172
+ push_to_hub_model_id = None,
173
+ push_to_hub_organization = None,
174
+ push_to_hub_token = None,
175
+ mp_parameters = '',
176
+ auto_find_batch_size = False,
177
+ full_determinism = False,
178
+ torchdynamo = None,
179
+ ray_scope = 'last',
180
+ ddp_timeout = 1800,
181
+ torch_compile = False,
182
+ torch_compile_backend = None,
183
+ torch_compile_mode = None,
184
+ include_tokens_per_second = False,
185
+ include_num_input_tokens_seen = False,
186
+ neftune_noise_alpha = None,
187
+ optim_target_modules = None,
188
+ batch_eval_metrics = False,
189
+ eval_on_start = False,
190
+ use_liger_kernel = False,
191
+ eval_use_gather_object = False,
192
+ average_tokens_across_devices = False,
193
+ reward_model_path = None,
194
+ judge = None,
195
+ max_new_tokens = 64,
196
+ max_length = 512,
197
+ temperature = 0.9,
198
+ missing_eos_penalty = None,
199
+ loss_type = 'sigmoid',
200
+ dataset_num_proc = None,
201
+ disable_dropout = True,
202
+ use_vllm = False,
203
+ ds3_gather_for_generation = True,
204
+ vllm_sampling_params = None,
205
+ unsloth_num_chunks = -1,
206
+ **kwargs,
207
+ ):
208
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
209
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
210
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
211
+ output_dir = 'unsloth_training_checkpoints'
212
+ save_strategy = 'no'
213
+ if dataset_num_proc is None:
214
+ from multiprocessing import cpu_count
215
+ dataset_num_proc = cpu_count()
216
+
217
+ super().__init__(
218
+ output_dir = output_dir,
219
+ overwrite_output_dir = overwrite_output_dir,
220
+ do_train = do_train,
221
+ do_eval = do_eval,
222
+ do_predict = do_predict,
223
+ eval_strategy = eval_strategy,
224
+ prediction_loss_only = prediction_loss_only,
225
+ per_device_train_batch_size = per_device_train_batch_size,
226
+ per_device_eval_batch_size = per_device_eval_batch_size,
227
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
228
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
229
+ gradient_accumulation_steps = gradient_accumulation_steps,
230
+ eval_accumulation_steps = eval_accumulation_steps,
231
+ eval_delay = eval_delay,
232
+ torch_empty_cache_steps = torch_empty_cache_steps,
233
+ learning_rate = learning_rate,
234
+ weight_decay = weight_decay,
235
+ adam_beta1 = adam_beta1,
236
+ adam_beta2 = adam_beta2,
237
+ adam_epsilon = adam_epsilon,
238
+ max_grad_norm = max_grad_norm,
239
+ num_train_epochs = num_train_epochs,
240
+ max_steps = max_steps,
241
+ lr_scheduler_type = lr_scheduler_type,
242
+ warmup_ratio = warmup_ratio,
243
+ warmup_steps = warmup_steps,
244
+ log_level = log_level,
245
+ log_level_replica = log_level_replica,
246
+ log_on_each_node = log_on_each_node,
247
+ logging_dir = logging_dir,
248
+ logging_strategy = logging_strategy,
249
+ logging_first_step = logging_first_step,
250
+ logging_steps = logging_steps,
251
+ logging_nan_inf_filter = logging_nan_inf_filter,
252
+ save_strategy = save_strategy,
253
+ save_steps = save_steps,
254
+ save_total_limit = save_total_limit,
255
+ save_safetensors = save_safetensors,
256
+ save_on_each_node = save_on_each_node,
257
+ save_only_model = save_only_model,
258
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
259
+ no_cuda = no_cuda,
260
+ use_cpu = use_cpu,
261
+ use_mps_device = use_mps_device,
262
+ seed = seed,
263
+ data_seed = data_seed,
264
+ jit_mode_eval = jit_mode_eval,
265
+ use_ipex = use_ipex,
266
+ bf16 = bf16,
267
+ fp16 = fp16,
268
+ fp16_opt_level = fp16_opt_level,
269
+ half_precision_backend = half_precision_backend,
270
+ bf16_full_eval = bf16_full_eval,
271
+ fp16_full_eval = fp16_full_eval,
272
+ tf32 = tf32,
273
+ local_rank = local_rank,
274
+ ddp_backend = ddp_backend,
275
+ tpu_num_cores = tpu_num_cores,
276
+ tpu_metrics_debug = tpu_metrics_debug,
277
+ debug = debug,
278
+ dataloader_drop_last = dataloader_drop_last,
279
+ eval_steps = eval_steps,
280
+ dataloader_num_workers = dataloader_num_workers,
281
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
282
+ past_index = past_index,
283
+ run_name = run_name,
284
+ disable_tqdm = disable_tqdm,
285
+ remove_unused_columns = remove_unused_columns,
286
+ label_names = label_names,
287
+ load_best_model_at_end = load_best_model_at_end,
288
+ metric_for_best_model = metric_for_best_model,
289
+ greater_is_better = greater_is_better,
290
+ ignore_data_skip = ignore_data_skip,
291
+ fsdp = fsdp,
292
+ fsdp_min_num_params = fsdp_min_num_params,
293
+ fsdp_config = fsdp_config,
294
+ tp_size = tp_size,
295
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
296
+ accelerator_config = accelerator_config,
297
+ deepspeed = deepspeed,
298
+ label_smoothing_factor = label_smoothing_factor,
299
+ optim = optim,
300
+ optim_args = optim_args,
301
+ adafactor = adafactor,
302
+ group_by_length = group_by_length,
303
+ length_column_name = length_column_name,
304
+ report_to = report_to,
305
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
306
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
307
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
308
+ dataloader_pin_memory = dataloader_pin_memory,
309
+ dataloader_persistent_workers = dataloader_persistent_workers,
310
+ skip_memory_metrics = skip_memory_metrics,
311
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
312
+ push_to_hub = push_to_hub,
313
+ resume_from_checkpoint = resume_from_checkpoint,
314
+ hub_model_id = hub_model_id,
315
+ hub_strategy = hub_strategy,
316
+ hub_token = hub_token,
317
+ hub_private_repo = hub_private_repo,
318
+ hub_always_push = hub_always_push,
319
+ gradient_checkpointing = gradient_checkpointing,
320
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
321
+ include_inputs_for_metrics = include_inputs_for_metrics,
322
+ eval_do_concat_batches = eval_do_concat_batches,
323
+ fp16_backend = fp16_backend,
324
+ push_to_hub_model_id = push_to_hub_model_id,
325
+ push_to_hub_organization = push_to_hub_organization,
326
+ push_to_hub_token = push_to_hub_token,
327
+ mp_parameters = mp_parameters,
328
+ auto_find_batch_size = auto_find_batch_size,
329
+ full_determinism = full_determinism,
330
+ torchdynamo = torchdynamo,
331
+ ray_scope = ray_scope,
332
+ ddp_timeout = ddp_timeout,
333
+ torch_compile = torch_compile,
334
+ torch_compile_backend = torch_compile_backend,
335
+ torch_compile_mode = torch_compile_mode,
336
+ include_tokens_per_second = include_tokens_per_second,
337
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
338
+ neftune_noise_alpha = neftune_noise_alpha,
339
+ optim_target_modules = optim_target_modules,
340
+ batch_eval_metrics = batch_eval_metrics,
341
+ eval_on_start = eval_on_start,
342
+ use_liger_kernel = use_liger_kernel,
343
+ eval_use_gather_object = eval_use_gather_object,
344
+ average_tokens_across_devices = average_tokens_across_devices,
345
+ reward_model_path = reward_model_path,
346
+ judge = judge,
347
+ max_new_tokens = max_new_tokens,
348
+ max_length = max_length,
349
+ temperature = temperature,
350
+ missing_eos_penalty = missing_eos_penalty,
351
+ loss_type = loss_type,
352
+ dataset_num_proc = dataset_num_proc,
353
+ disable_dropout = disable_dropout,
354
+ use_vllm = use_vllm,
355
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
356
+ self.vllm_sampling_params = vllm_sampling_params
357
+ self.unsloth_num_chunks = unsloth_num_chunks
358
+ pass
359
+
360
+ class _UnslothXPOTrainer(OnlineDPOTrainer):
361
+ r""""""
362
+
363
+ _tag_names = ["trl", "xpo"]
364
+
365
+ def __init__(
366
+ self,
367
+ model: Union[PreTrainedModel, nn.Module] = None,
368
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
369
+ reward_model: Optional[nn.Module] = None,
370
+ judge: Optional[BasePairwiseJudge] = None,
371
+ args: Optional[XPOConfig] = None,
372
+ data_collator: Optional[Callable] = None,
373
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
374
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
375
+ processing_class: Optional[
376
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
377
+ ] = None,
378
+ peft_config: Optional[dict] = None,
379
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
380
+ callbacks: Optional[list[TrainerCallback]] = None,
381
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
382
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
383
+ ) -> None:
384
+ super().__init__(
385
+ model=model,
386
+ ref_model=ref_model,
387
+ judge=judge,
388
+ reward_model=reward_model,
389
+ args=args,
390
+ data_collator=data_collator,
391
+ train_dataset=train_dataset,
392
+ eval_dataset=eval_dataset,
393
+ processing_class=processing_class,
394
+ reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
395
+ peft_config=peft_config,
396
+ compute_metrics=compute_metrics,
397
+ callbacks=callbacks,
398
+ optimizers=optimizers,
399
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
400
+ )
401
+
402
+ self._alpha = self.args.alpha
403
+
404
+ # Overwrite the stats dictionary to include XPO specific statistics
405
+ self.stats = {
406
+ # Remove "non_score_reward", "rlhf_reward", "scores"
407
+ # Add "loss/dpo", "loss/xpo"
408
+ "loss/dpo": [],
409
+ "loss/xpo": [],
410
+ "objective/kl": [],
411
+ "objective/entropy": [],
412
+ "rewards/chosen": [],
413
+ "rewards/rejected": [],
414
+ "rewards/accuracies": [],
415
+ "rewards/margins": [],
416
+ "logps/chosen": [],
417
+ "logps/rejected": [],
418
+ # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
419
+ "val/model_contain_eos_token": [],
420
+ "val/ref_contain_eos_token": [],
421
+ "alpha": [],
422
+ "beta": [],
423
+ }
424
+ if self.reward_model is not None:
425
+ # Replace "scores" by "model_scores" and "ref_scores"
426
+ self.stats["objective/model_scores"] = []
427
+ self.stats["objective/ref_scores"] = []
428
+ self.stats["objective/scores_margin"] = []
429
+
430
+ @property
431
+ def alpha(self):
432
+ if isinstance(self._alpha, list):
433
+ epoch = self.state.epoch
434
+ return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
435
+ else:
436
+ return self._alpha
437
+
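The schedule implemented by the property above, when alpha is given as a list: epoch i uses alpha[i], and epochs past the end of the list reuse the final value. With a hypothetical schedule:

_alpha = [1e-5, 5e-6, 1e-6]
for epoch in range(5):
    current = _alpha[epoch] if epoch < len(_alpha) else _alpha[-1]
    # epochs 0-2 -> 1e-5, 5e-6, 1e-6; epochs 3-4 -> 1e-6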
438
+ def _generate_completions(self, prompts, model):
439
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
440
+ model_output = unwrapped_model.generate(
441
+ input_ids=prompts["input_ids"],
442
+ attention_mask=prompts["attention_mask"],
443
+ generation_config=self.generation_config,
444
+ )
445
+
446
+ ref_model = model if self.ref_model is None else self.ref_model
447
+ with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
448
+ ref_output = unwrapped_ref_model.generate(
449
+ input_ids=prompts["input_ids"],
450
+ attention_mask=prompts["attention_mask"],
451
+ generation_config=self.generation_config,
452
+ )
453
+
454
+ return model_output, ref_output
455
+
456
+ def _process_completions(self, model_output, ref_output, prompts):
457
+ context_length = prompts["input_ids"].shape[1]
458
+
459
+ # Process model completions
460
+ model_completion_ids = model_output[:, context_length:]
461
+ model_completion_ids, model_completion_mask = truncate_right(
462
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
463
+ )
464
+ model_data = {
465
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
466
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
467
+ "raw": prompts["raw"],
468
+ }
469
+
470
+ # Process reference model completions
471
+ ref_completion_ids = ref_output[:, context_length:]
472
+ ref_completion_ids, ref_completion_mask = truncate_right(
473
+ ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
474
+ )
475
+ ref_data = {
476
+ "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
477
+ "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
478
+ "raw": prompts["raw"],
479
+ }
480
+
481
+ return model_data, ref_data
482
+
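truncate_right (imported from TRL above) is assumed here to cut each completion at its first EOS, pad what follows, and return the matching mask; a toy illustration with eos_token_id = 2 and pad_token_id = 0:

# completion ids:     [[5, 7, 2, 9, 9]]
# -> truncated ids:   [[5, 7, 2, 0, 0]]
# -> completion mask: [[1, 1, 1, 0, 0]]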
483
+ def _compute_rewards(self, model_data, ref_data, context_length):
484
+ with torch.no_grad():
485
+ _, model_scores, _ = get_reward(
486
+ self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
487
+ )
488
+ _, ref_scores, _ = get_reward(
489
+ self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
490
+ )
491
+
492
+ # Apply EOS penalty if needed
493
+ if self.args.missing_eos_penalty is not None:
494
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
495
+ ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
496
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
497
+ ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
498
+
499
+ return model_scores, ref_scores
500
+
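The EOS penalty above in isolation, with hypothetical scores: any completion that never emitted EOS has a fixed penalty subtracted from its reward.

import torch

scores = torch.tensor([1.2, 0.4])
contains_eos = torch.tensor([True, False])
scores[~contains_eos] -= 1.0  # assuming missing_eos_penalty = 1.0
# scores is now tensor([1.2000, -0.6000])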
501
+ def _compute_judge(self, model_data, ref_data, context_length):
502
+ prompts = model_data["raw"]
503
+ model_data_completions = self.processing_class.batch_decode(
504
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
505
+ )
506
+ model_data_completions = [completion.strip() for completion in model_data_completions]
507
+
508
+ ref_data_completions = self.processing_class.batch_decode(
509
+ ref_data["input_ids"][:, context_length:], skip_special_tokens=True
510
+ )
511
+ ref_data_completions = [completion.strip() for completion in ref_data_completions]
512
+
513
+ if is_conversational({"prompt": prompts[0]}):
514
+ model_data_completions = [
515
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
516
+ ]
517
+ environment = jinja2.Environment()
518
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
519
+ prompts = [template.render(messages=message) for message in prompts]
520
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
521
+
522
+ ref_data_completions = [
523
+ [{"role": "assistant", "content": completion}] for completion in ref_data_completions
524
+ ]
525
+ ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
526
+
527
+ ranks_of_first_completion = self.judge.judge(
528
+ prompts,
529
+ list(zip(model_data_completions, ref_data_completions)),
530
+ )
531
+ # convert ranks to a True/False mask:
532
+ # when rank == 0, it means the first completion is the best
533
+ # when rank == 1, it means the second completion is the best
534
+ return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
535
+
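The judge returns, for each prompt, the rank of the model's completion within the (model, ref) pair; the final conversion to a boolean mask, with hypothetical ranks:

ranks = [0, 1, 0]  # 0 -> the model completion was preferred
chosen_mask = [rank == 0 for rank in ranks]  # -> [True, False, True]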
536
+     def _compute_logprobs(self, model, model_data, ref_data, context_length):
+         def compute_logprobs_for_data(m, data):
+             output = m(data["input_ids"], attention_mask=data["attention_mask"])
+             logits = output.logits[:, context_length - 1 : -1]
+             token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
+             return token_logprobs
+
+         # Compute logprobs for model completions
+         model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
+         # Compute logprobs for model on reference completions (for XPO loss)
+         model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
+
+         # Compute logprobs for reference model completions
+         with torch.no_grad():
+             if self.ref_model is None:
+                 with model.disable_adapter():
+                     ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
+                     ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
+             else:
+                 ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
+                 ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
+
+         # Mask padding tokens
+         model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
+         ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
+         model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
+         model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
+         ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
+         ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
+
+         return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
+
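+     # Note: `selective_log_softmax(logits, ids)` gathers per-token log-probs,
+     # equivalent (up to numerics) to:
+     #     logits.log_softmax(dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
+     # The `context_length - 1 : -1` slice aligns position t's logits with token t+1,
+     # so only completion tokens (never prompt tokens) are scored.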
+     def _compute_losses(
+         self,
+         model_logprobs_model_data,
+         model_logprobs_ref_data,
+         ref_logprobs_ref_data,
+         ref_logprobs_model_data,
+         chosen_mask,
+     ):
+         # Compute log probs
+         model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
+         model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
+         ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
+         ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
+
+         chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+         chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+         chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
+
+         rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+         rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+         rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
+
+         # Compute logits as the difference between chosen and rejected log ratios
+         logits = chosen_log_ratios - rejected_log_ratios
+
+         if self.args.loss_type == "sigmoid":
+             dpo_losses = -F.logsigmoid(self.beta * logits)
+         elif self.args.loss_type == "ipo":
+             dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
+         else:
+             raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
+
+         # Compute XPO specific loss
+         xpo_losses = self.alpha * model_logprobs_ref_data_sum
+
+         # Total loss
+         loss = (dpo_losses + xpo_losses).mean()
+
+         return loss, dpo_losses, xpo_losses
+
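+     # Note: for loss_type == "sigmoid", the per-example objective computed above is
+     #     L_i = -log(sigmoid(beta * (chosen_log_ratio_i - rejected_log_ratio_i)))
+     #           + alpha * sum_t log pi_theta(ref_completion_token_t)
+     # where each log-ratio is log pi_theta - log pi_ref summed over completion tokens.
+     # Minimizing the alpha term lowers the policy's likelihood of the reference
+     # model's own samples, which is XPO's optimistic exploration bonus.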
+     def _log_statistics(
+         self,
+         model_data,
+         ref_data,
+         model_logprobs_model_data,
+         model_logprobs_ref_data,
+         ref_logprobs_ref_data,
+         ref_logprobs_model_data,
+         chosen_mask,
+         dpo_losses,
+         xpo_losses,
+         context_length,
+         model_scores=None,
+         ref_scores=None,
+     ):
+         # Helper function to gather and compute mean
+         def gather_mean(tensor):
+             return self.accelerator.gather_for_metrics(tensor).mean().item()
+
+         # Log losses
+         self.stats["loss/dpo"].append(gather_mean(dpo_losses))
+         self.stats["loss/xpo"].append(gather_mean(xpo_losses))
+
+         # Log scores
+         if self.reward_model is not None:
+             self.stats["objective/model_scores"].append(gather_mean(model_scores))
+             self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
+             self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
+
+         # Log logprobs
+         model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
+         model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
+         ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
+         ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
+
+         chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+         chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+         chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
+
+         rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+         rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+         rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
+
+         self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
+         self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
+
+         # Log rewards and compute various statistics
+         chosen_rewards = chosen_log_ratios * self.beta
+         rejected_rewards = rejected_log_ratios * self.beta
+         self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
+         self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
+
+         # Calculate KL divergence for model and ref data
+         kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
+         kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
+         mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
+         self.stats["objective/kl"].append(gather_mean(mean_kl))
+
+         # Calculate entropy for model and ref data
+         entropy_model_data = -model_logprobs_model_data.sum(1)
+         entropy_ref_data = -model_logprobs_ref_data.sum(1)
+         mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
+         self.stats["objective/entropy"].append(gather_mean(mean_entropy))
+
+         # Calculate margins
+         margin = chosen_rewards - rejected_rewards
+         self.stats["rewards/margins"].append(gather_mean(margin.mean()))
+
+         # Calculate accuracy
+         accuracy = (margin > 0).float()
+         self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
+
+         # Log EOS token statistics
+         model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
+         ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
+         self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
+         self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
+
+         # Log alpha and beta
+         self.stats["alpha"].append(self.alpha)
+         self.stats["beta"].append(self.beta)
+
+     def training_step(
+         self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+     ) -> torch.Tensor:
+         model.train()
+
+         # Apply chat template and tokenize the input
+         batch_size = len(next(iter(inputs.values())))
+         prompts = inputs["prompt"]
+         inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
+         inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
+         inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
+         inputs = self.data_collator(inputs)
+
+         # Keep only the prompt_* tensors from here on
+         inputs = self._prepare_inputs(inputs)
+         context_length = inputs["prompt_input_ids"].shape[1]
+         prompts = {
+             "input_ids": inputs["prompt_input_ids"],
+             "attention_mask": inputs["prompt_attention_mask"],
+             "raw": prompts,
+         }
+         del inputs
+
+         # Sample completions from both the model and the reference model
+         model_output, ref_output = self._generate_completions(prompts, model)
+
+         # Process model and reference completions
+         model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
+
+         # Compute rewards
+         if self.reward_model is not None:
+             model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
+             chosen_mask = model_scores >= ref_scores
+         else:
+             model_scores, ref_scores = None, None
+             chosen_mask = self._compute_judge(model_data, ref_data, context_length)
+
+         # Compute logprobs
+         model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
+             self._compute_logprobs(model, model_data, ref_data, context_length)
+         )
+
+         # Compute loss
+         loss, dpo_losses, xpo_losses = self._compute_losses(
+             model_logprobs_model_data,
+             model_logprobs_ref_data,
+             ref_logprobs_ref_data,
+             ref_logprobs_model_data,
+             chosen_mask,
+         )
+
+         # Log everything
+         self._log_statistics(
+             model_data,
+             ref_data,
+             model_logprobs_model_data.detach(),
+             model_logprobs_ref_data.detach(),
+             ref_logprobs_ref_data,
+             ref_logprobs_model_data,
+             chosen_mask,
+             dpo_losses.detach(),
+             xpo_losses.detach(),
+             context_length,
+             model_scores,
+             ref_scores,
+         )
+
+         if (
+             self.args.torch_empty_cache_steps is not None
+             and self.state.global_step % self.args.torch_empty_cache_steps == 0
+         ):
+             empty_cache()
+
+         kwargs = {}
+         # For LOMO optimizers you need to explicitly use the learning rate
+         if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+             kwargs["learning_rate"] = self._get_learning_rate()
+
+         if self.args.n_gpu > 1:
+             loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+         if self.use_apex:
+             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                 scaled_loss.backward()
+         else:
+             self.accelerator.backward(loss, **kwargs)
+
+         return loss.detach() / self.args.gradient_accumulation_steps
+
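+     # Note: `accelerator.backward` above receives the unscaled loss; the division by
+     # `gradient_accumulation_steps` only rescales the detached value returned for
+     # logging, matching the HF `Trainer.training_step` convention.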
+     def create_model_card(
+         self,
+         model_name: Optional[str] = None,
+         dataset_name: Optional[str] = None,
+         tags: Union[str, list[str], None] = None,
+     ):
+         """
+         Creates a draft of a model card using the information available to the `Trainer`.
+
+         Args:
+             model_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the model.
+             dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the dataset used for training.
+             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                 Tags to be associated with the model card.
+         """
+         if not self.is_world_process_zero():
+             return
+
+         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+             base_model = self.model.config._name_or_path
+         else:
+             base_model = None
+
+         tags = tags or []
+         if isinstance(tags, str):
+             tags = [tags]
+
+         if hasattr(self.model.config, "unsloth_version"):
+             tags.append("unsloth")
+
+         citation = textwrap.dedent("""\
+         @article{xie2024exploratory,
+             title        = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
+             author       = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
+             year         = 2024,
+             eprint       = {arXiv:2405.21046}
+         }""")
+
+         model_card = generate_model_card(
+             base_model=base_model,
+             model_name=model_name,
+             hub_model_id=self.hub_model_id,
+             dataset_name=dataset_name,
+             tags=tags,
+             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+             comet_url=get_comet_experiment_url(),
+             trainer_name="XPO",
+             trainer_citation=citation,
+             paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
+             paper_id="2405.21046",
+         )
+
+         model_card.save(os.path.join(self.args.output_dir, "README.md"))
+ class UnslothXPOTrainer(_UnslothXPOTrainer):
+     """
+
+     Initialize XPOTrainer as a subclass of [`OnlineDPOTrainer`].
+
+     Args:
+         model (`transformers.PreTrainedModel`):
+             The model to train, preferably an `AutoModelForCausalLM`.
+         ref_model (`PreTrainedModelWrapper`):
+             Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
+             reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
+         reward_model (`transformers.PreTrainedModel`):
+             The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
+         judge (`BasePairwiseJudge`):
+             The judge to use for pairwise comparison of model completions.
+         args (`XPOConfig`):
+             The XPO config arguments to use for training.
+         data_collator (`transformers.DataCollator`):
+             The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
+             which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+         train_dataset (`datasets.Dataset`):
+             The dataset to use for training.
+         eval_dataset (`datasets.Dataset`):
+             The dataset to use for evaluation.
+         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+             Processing class used to process the data. If provided, will be used to automatically process the inputs
+             for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+             reuse the fine-tuned model.
+         peft_config (`dict`):
+             The peft config to use for training.
+         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+             The function to use to compute the metrics. Must take an `EvalPrediction` and return
+             a dictionary mapping strings to metric values.
+         callbacks (`list[transformers.TrainerCallback]`):
+             The callbacks to use for training.
+         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+             The optimizer and scheduler to use for training.
+         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+             The function to use to preprocess the logits before computing the metrics.
+
+     """
+     def __init__(
+         self,
+         model = None,
+         ref_model = None,
+         reward_model = None,
+         judge = None,
+         args = None,
+         data_collator = None,
+         train_dataset = None,
+         eval_dataset = None,
+         processing_class = None,
+         peft_config = None,
+         compute_metrics = None,
+         callbacks = None,
+         preprocess_logits_for_metrics = None,
+         **kwargs
+     ):
+         if args is None: args = UnslothXPOConfig()
+         use_bf16 = getattr(args, 'bf16', False)
+         use_fp16 = getattr(args, 'fp16', False)
+         force_float32 = False
+         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+             print('Unsloth: Switching to float32 training since model cannot work with float16')
+             force_float32 = True
+         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+         dtype = getattr(model.config, 'torch_dtype', None)
+         if dtype is None: dtype = model.get_input_embeddings().dtype
+         from unsloth_zoo.utils import _get_dtype
+         dtype = _get_dtype(dtype)
+         float16 = dtype == torch.float16
+         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+         if force_float32:
+             args.fp16 = False
+             args.bf16 = False
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+             args.fp16 = float16
+             args.bf16 = not float16
+             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
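+         # Note: the dtype checks above read `model.config.torch_dtype` (falling back
+         # to the input-embedding dtype), so a float16 checkpoint combined with
+         # `bf16=True` fails fast instead of silently mixing precisions.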
+         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+             args.eval_strategy = 'steps'
+             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+         if ga_steps is not None and ga_steps > 1:
+             from transformers import __version__ as transformers_version
+             if Version(transformers_version) <= Version('4.45.2'):
+                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+         if getattr(args, 'eval_strategy', 'no') != 'no':
+             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+         if force_float32:
+             args.bf16_full_eval = False
+             args.fp16_full_eval = False
+         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+             args.bf16_full_eval = True
+             args.fp16_full_eval = False
+         elif not bf16_full_eval and not fp16_full_eval:
+             args.bf16_full_eval = args.bf16
+             args.fp16_full_eval = args.fp16
+         _output_logits = False
+         if locals().get('compute_metrics', None) is not None: _output_logits = True
+         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+         if _output_logits:
+             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+             pass
+         else:
+             model_max_seq_length = getattr(model, 'max_seq_length', None)
+             args_max_seq_length = getattr(args, 'max_seq_length', None)
+             if args_max_seq_length is None and model_max_seq_length is not None:
+                 max_seq_length = model.max_seq_length
+                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+         if model is not None and hasattr(model, 'for_training'):
+             model.for_training()
+         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+         if 'processing_class' in locals():
+             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
+         else:
+             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+         if not isinstance(data_collator, UnslothVisionDataCollator):
+             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                 if isinstance(data_collator, DataCollatorForSeq2Seq):
+                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+                 else:
+                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
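+         # Note: the swaps above pair the collator with the dataset.
+         # `DataCollatorForSeq2Seq` expects precomputed `labels`, while
+         # `DataCollatorForLanguageModeling(mlm=False)` derives labels from
+         # `input_ids`; a processor without `.pad` is unwrapped to its tokenizer.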
+         other_metrics = []
+
+         from unsloth_zoo.logging_utils import PatchRLStatistics
+         PatchRLStatistics('xpo_trainer', other_metrics)
+
+         super().__init__(
+             model = model,
+             ref_model = ref_model,
+             reward_model = reward_model,
+             judge = judge,
+             args = args,
+             data_collator = data_collator,
+             train_dataset = train_dataset,
+             eval_dataset = eval_dataset,
+             processing_class = processing_class,
+             peft_config = peft_config,
+             compute_metrics = compute_metrics,
+             callbacks = callbacks,
+             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+             **kwargs,
+         )
+         if hasattr(self, 'neftune_hook_handle'):
+             self.neftune_hook_handle.remove()
+             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+         if getattr(args, 'neftune_noise_alpha', None) is not None:
+             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+         pass
+
+ pass
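+ # Minimal usage sketch, assuming TRL's `XPOConfig` and `PairRMJudge` and
+ # placeholder objects (`base_model`, `tokenizer`, `prompt_dataset`) built
+ # elsewhere, e.g. via `FastLanguageModel.from_pretrained`:
+ #
+ #     from trl import XPOConfig, PairRMJudge
+ #     training_args = XPOConfig(output_dir="xpo-out", per_device_train_batch_size=1)
+ #     trainer = UnslothXPOTrainer(
+ #         model=base_model,
+ #         judge=PairRMJudge(),           # or pass `reward_model=...` instead
+ #         args=training_args,
+ #         train_dataset=prompt_dataset,  # prompt-only dataset
+ #         processing_class=tokenizer,
+ #     )
+ #     trainer.train()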
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-312.pyc ADDED
Binary file (31.7 kB)
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc ADDED
Binary file (84.4 kB)
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc ADDED
Binary file (69.3 kB)
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-312.pyc ADDED
Binary file (43.6 kB)
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc ADDED
Binary file (97.4 kB)
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc ADDED
Binary file (34.5 kB)
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc ADDED
Binary file (70.3 kB)
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc ADDED
Binary file (80.4 kB)
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc ADDED
Binary file (41.6 kB)
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc ADDED
Binary file (68.9 kB)
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc ADDED
Binary file (60.3 kB)
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc ADDED
Binary file (59.6 kB)
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc ADDED
Binary file (32.9 kB)
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc ADDED
Binary file (51.6 kB)
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc ADDED
Binary file (35 kB)
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc ADDED
Binary file (43.2 kB)
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc ADDED
Binary file (44 kB)