alimosavian Ali Mosavian winglian committed on
Commit
1e1921b
·
unverified ·
1 Parent(s): 1634ac8

FIX: max_length and max_prompt_length were not being sent to ORPOTrainer (#1584)

Browse files

* FIX: TRL trainer preprocessing step was running in one process

* FIX: max_length and max_prompt_length were not being sent to ORPOTrainer

* FIX: Change ORPO max prompt length to 1/4 of max length, otherwise we get strange behaviour

* FIX: Removed change from a different PR

* FIX: Black fix

* explicitly set max prompt len for orpo config

---------

Co-authored-by: Ali Mosavian <[email protected]>
Co-authored-by: Wing Lian <[email protected]>

src/axolotl/core/trainer_builder.py CHANGED
@@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1526
  if self.cfg.rl == "orpo":
1527
  training_args_cls = ORPOConfig
1528
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
 
 
1529
 
1530
  training_args = training_args_cls(
1531
  per_device_train_batch_size=self.cfg.micro_batch_size,
 
1526
  if self.cfg.rl == "orpo":
1527
  training_args_cls = ORPOConfig
1528
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
1529
+ training_args_kwargs["max_length"] = self.cfg.sequence_len
1530
+ if self.cfg.max_prompt_len:
1531
+ training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1532
 
1533
  training_args = training_args_cls(
1534
  per_device_train_batch_size=self.cfg.micro_batch_size,
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -517,6 +517,9 @@ class AxolotlInputConfig(
517
 
518
  sequence_len: int = Field(default=512)
519
  min_sample_len: Optional[int] = None
 
 
 
520
  sample_packing: Optional[bool] = None
521
  eval_sample_packing: Optional[bool] = None
522
  pad_to_sequence_len: Optional[bool] = None
 
517
 
518
  sequence_len: int = Field(default=512)
519
  min_sample_len: Optional[int] = None
520
+ max_prompt_len: int = Field(
521
+ default=512, metadata={"help": "maximum prompt length for RL training"}
522
+ )
523
  sample_packing: Optional[bool] = None
524
  eval_sample_packing: Optional[bool] = None
525
  pad_to_sequence_len: Optional[bool] = None