ORPO (#1419)

* orpo trainer
* rl handling for orpo
* support for remove_unused_columns
* orpo fixes
* fix loader for orpo
* chore: lint
* fix default for remove_unused_columns
* roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora
* better handling of system message for orpo
* revert system prompt changes for chat templates
* no need for else condition
* split dataset parsing into its own component
- docs/rlhf.md +15 -0
- src/axolotl/cli/preprocess.py +1 -1
- src/axolotl/cli/train.py +1 -1
- src/axolotl/core/trainer_builder.py +142 -1
- src/axolotl/prompt_strategies/base.py +20 -0
- src/axolotl/prompt_strategies/dpo/__init__.py +3 -15
- src/axolotl/prompt_strategies/orpo/__init__.py +9 -0
- src/axolotl/prompt_strategies/orpo/chat_template.py +187 -0
- src/axolotl/train.py +1 -1
- src/axolotl/utils/chat_templates.py +1 -1
- src/axolotl/utils/config/__init__.py +5 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +5 -0
- src/axolotl/utils/freeze.py +5 -3
- tests/test_prompt_tokenizers.py +56 -1
docs/rlhf.md
CHANGED

@@ -34,6 +34,21 @@ datasets:
 rl: ipo
 ```

+#### ORPO
+
+Paper: https://arxiv.org/abs/2403.07691
+
+```yaml
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned
+    type: orpo.chat_template
+```
+
 #### Using local dataset files
 ```yaml
 datasets:
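For context on what `orpo_alpha` controls: per the paper linked above, ORPO needs no reference model (unlike DPO/IPO) and instead augments the standard SFT loss with a weighted odds-ratio penalty; `orpo_alpha` is the weight on that penalty (λ in the paper's notation). A sketch of the objective, following the paper:

```latex
\mathcal{L}_{\mathrm{ORPO}}
  = \mathbb{E}_{(x,\,y_w,\,y_l)}\left[\mathcal{L}_{\mathrm{SFT}}
      + \lambda \cdot \mathcal{L}_{\mathrm{OR}}\right],
\quad
\mathcal{L}_{\mathrm{OR}}
  = -\log \sigma\left(\log
      \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)}\right),
\quad
\mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}
```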
src/axolotl/cli/preprocess.py
CHANGED

@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

-    if parsed_cfg.rl:
+    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
src/axolotl/cli/train.py
CHANGED

@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    else:
        register_chatml_template()

-   if cfg.rl:
+   if cfg.rl and cfg.rl != "orpo":
        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
    else:
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
src/axolotl/core/trainer_builder.py
CHANGED

@@ -11,10 +11,11 @@ import math
 import os
 import sys
 from abc import abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Type, Union

 import torch
 import transformers

@@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "whether this is a qlora training"},
     )
+    orpo_alpha: Optional[float] = field(
+        default=None,
+    )


 class AxolotlTrainer(Trainer):

@@ -223,6 +227,9 @@ class AxolotlTrainer(Trainer):
         self.eval_data_collator = eval_data_collator
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        if self.args.orpo_alpha:
+            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

     def create_optimizer(self):
         if self.args.loraplus_lr_ratio is None:

@@ -465,8 +472,112 @@ class AxolotlTrainer(Trainer):
         # outputs = model(**inputs)
         # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         # return (loss, outputs) if return_outputs else loss
+        if self.args.orpo_alpha:
+            return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
         return super().compute_loss(model, inputs, return_outputs=return_outputs)

+    def orpo_compute_custom_loss(self, logits, labels):
+        logits = logits.contiguous()
+        loss = 0.0
+
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+            # Flatten the tokens
+            loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
+                dim=-1
+            )
+
+        return loss
+
+    def orpo_compute_logps(
+        self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
+    ):
+        # Get the shape of chosen_attention_mask[:, :-1]
+        chosen_shape = chosen_attention_mask[:, :-1].shape
+
+        # Calculate the padding size
+        pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
+
+        # Pad prompt_attention_mask with zeros to match the desired shape
+        prompt_attention_mask_padded = torch.nn.functional.pad(
+            prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
+        )
+
+        # Perform the subtraction operation
+        mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
+
+        per_token_logps = torch.gather(
+            logits[:, :-1, :].log_softmax(-1),
+            dim=2,
+            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
+        ).squeeze(2)
+        return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
+            dtype=torch.float64
+        ) / mask.sum(dim=1).to(dtype=torch.float64)
+
+    def orpo_compute_loss(self, model, inputs, return_outputs=False):
+        outputs_neg = model(
+            **{
+                "input_ids": inputs["rejected_input_ids"],
+                "attention_mask": inputs["rejected_attention_mask"],
+                "labels": inputs["rejected_labels"],
+            },
+            output_hidden_states=True,
+        )
+        outputs_pos = model(
+            **{
+                "input_ids": inputs["input_ids"],
+                "attention_mask": inputs["attention_mask"],
+                "labels": inputs["labels"],
+            },
+            output_hidden_states=True,
+        )
+
+        # Calculate NLL loss
+        pos_loss = self.orpo_compute_custom_loss(
+            logits=outputs_pos.logits, labels=inputs["input_ids"]
+        )
+
+        # Calculate Log Probability
+        pos_prob = self.orpo_compute_logps(
+            prompt_attention_mask=inputs["prompt_attention_mask"],
+            chosen_inputs=inputs["input_ids"],
+            chosen_attention_mask=inputs["attention_mask"],
+            logits=outputs_pos.logits,
+        )
+        neg_prob = self.orpo_compute_logps(
+            prompt_attention_mask=inputs["prompt_attention_mask"],
+            chosen_inputs=inputs["rejected_input_ids"],
+            chosen_attention_mask=inputs["rejected_attention_mask"],
+            logits=outputs_neg.logits,
+        )
+
+        # Calculate log odds
+        log_odds = (pos_prob - neg_prob) - (
+            torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
+        )
+        sig_ratio = torch.nn.functional.sigmoid(log_odds)
+        ratio = torch.log(sig_ratio)
+
+        # Calculate the Final Loss
+        loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
+            dtype=torch.bfloat16
+        )
+
+        metrics = {}
+        metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
+        metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
+        metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
+        metrics["log_odds"] = torch.mean(log_odds).cpu().item()
+        self.store_metrics(metrics, train_eval="train")
+
+        return (loss, outputs_pos) if return_outputs else loss
+
     @wraps(Trainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """

@@ -527,6 +638,28 @@ class AxolotlTrainer(Trainer):

         return res

+    def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs)
+
+    def store_metrics(
+        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
+    ) -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+

 class AxolotlMambaTrainer(AxolotlTrainer):
     """

@@ -903,6 +1036,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
             training_arguments_kwargs["dataloader_drop_last"] = True

+        if self.cfg.remove_unused_columns is not None:
+            training_arguments_kwargs[
+                "remove_unused_columns"
+            ] = self.cfg.remove_unused_columns
+
         if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
             training_arguments_kwargs["evaluation_strategy"] = "no"

@@ -1070,6 +1208,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

+        if self.cfg.rl == "orpo":
+            training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
+
         if self.cfg.neftune_noise_alpha is not None:
             training_arguments_kwargs[
                 "neftune_noise_alpha"
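The new `orpo_compute_loss` above is the heart of the change: the chosen sequence keeps its usual NLL loss, and the preference signal enters only through `ratio`, the log-sigmoid of the log odds ratio between the chosen and rejected completions. A minimal standalone sketch of that arithmetic on toy values; the numbers are invented, and in the trainer `pos_prob`/`neg_prob` come from `orpo_compute_logps` as length-normalized token log-probabilities:

```python
import torch

# Toy length-normalized sequence log-probs, log P(y | x); invented values.
pos_prob = torch.tensor([-0.8, -1.1], dtype=torch.float64)  # chosen
neg_prob = torch.tensor([-1.5, -1.9], dtype=torch.float64)  # rejected

# log odds ratio: log odds(chosen) - log odds(rejected), where
# log odds(p) = log p - log(1 - p), computed here from log-probs.
log_odds = (pos_prob - neg_prob) - (
    torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
ratio = torch.log(torch.sigmoid(log_odds))  # log sigma(log odds ratio), <= 0

orpo_alpha = 0.1
pos_loss = torch.tensor([1.2, 0.9], dtype=torch.float64)  # toy per-sample NLL
loss = torch.mean(pos_loss - orpo_alpha * ratio)
print(loss)  # scalar loss; a larger log-odds gap means a smaller penalty
```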
src/axolotl/prompt_strategies/base.py
ADDED

@@ -0,0 +1,20 @@
+"""
+module for base dataset transform strategies
+"""
+
+import importlib
+import logging
+
+LOG = logging.getLogger("axolotl")
+
+
+def load(strategy, cfg, module_base=None, **kwargs):
+    try:
+        load_fn = strategy.split(".")[-1]
+        strategy = ".".join(strategy.split(".")[:-1])
+        mod = importlib.import_module(f".{strategy}", module_base)
+        func = getattr(mod, load_fn)
+        return func(cfg, **kwargs)
+    except Exception:  # pylint: disable=broad-exception-caught
+        LOG.warning(f"unable to load strategy {strategy}")
+        return None
src/axolotl/prompt_strategies/dpo/__init__.py
CHANGED

@@ -1,20 +1,8 @@
 """
 module for DPO style dataset transform strategies
 """
-import importlib
-import logging
-
-LOG = logging.getLogger("axolotl")
-
-
-def load(strategy, cfg, **kwargs):
-    try:
-        load_fn = strategy.split(".")[-1]
-        strategy = ".".join(strategy.split(".")[:-1])
-        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
-        func = getattr(mod, load_fn)
-        return func(cfg, **kwargs)
-    except Exception:  # pylint: disable=broad-exception-caught
-        LOG.warning(f"unable to load strategy {strategy}")
-        return None
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module_base="axolotl.prompt_strategies.dpo")
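The shared loader splits a dotted strategy string into a module path and a factory function relative to the configured base package. A hedged sketch of the resolution; the `chatml.intel` strategy name and the empty `cfg` are illustrative only:

```python
from axolotl.prompt_strategies.dpo import load
from axolotl.utils.dict import DictDefault

cfg = DictDefault({})  # illustrative; real configs carry many more keys

# "chatml.intel" -> module "axolotl.prompt_strategies.dpo.chatml",
# factory function "intel"; returns None (with a warning) if either
# the import or the attribute lookup fails.
transform_fn = load("chatml.intel", cfg)
```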
src/axolotl/prompt_strategies/orpo/__init__.py
ADDED

@@ -0,0 +1,9 @@
+"""
+module for ORPO style dataset transform strategies
+"""
+
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
src/axolotl/prompt_strategies/orpo/chat_template.py
ADDED

@@ -0,0 +1,187 @@
+"""chatml prompt tokenization strategy for ORPO"""
+from typing import Any, Dict, Generator, List, Optional, Tuple
+
+from pydantic import BaseModel
+
+from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
+from axolotl.prompters import Prompter
+from axolotl.utils.chat_templates import chat_templates
+
+
+class Message(BaseModel):
+    """message/turn"""
+
+    role: str
+    content: str
+    label: Optional[bool] = None
+
+
+class MessageList(BaseModel):
+    """conversation"""
+
+    messages: List[Message]
+
+
+def load(
+    tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    chatml transforms for datasets with system, input, chosen, rejected
+    """
+
+    chat_template = chat_templates("chatml")
+    if ds_cfg and "chat_template" in ds_cfg:
+        chat_template = ds_cfg["chat_template"]
+        try:
+            chat_template = chat_templates(chat_template)
+        except ValueError:
+            pass
+
+    return ORPOTokenizingStrategy(
+        ORPOPrompter(chat_template, tokenizer),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+        dataset_parser=ORPODatasetParsingStrategy(),
+    )
+
+
+class ORPODatasetParsingStrategy:
+    """Strategy to parse chosen rejected dataset into messagelist"""
+
+    def get_chosen_conversation_thread(self, prompt) -> MessageList:
+        """Dataset structure mappings"""
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        messages.append(Message(role="user", content=prompt["prompt"], label=False))
+        messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][1]["content"], label=True
+            )
+        )
+        return MessageList(messages=messages)
+
+    def get_rejected_conversation_thread(self, prompt) -> MessageList:
+        """Dataset structure mappings"""
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        messages.append(Message(role="user", content=prompt["prompt"], label=False))
+        messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][1]["content"], label=True
+            )
+        )
+        return MessageList(messages=messages)
+
+
+class ORPOTokenizingStrategy(PromptTokenizingStrategy):
+    """
+    rejected_input_ids
+    input_ids
+    rejected_attention_mask
+    attention_mask
+    rejected_labels
+    labels
+    """
+
+    def __init__(
+        self,
+        *args,
+        dataset_parser=None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.dataset_parser = dataset_parser
+
+    def tokenize_prompt(self, prompt):
+        # pass the rejected prompt/row to the Prompter to get the formatted prompt
+        prompt_len = 0
+        rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
+            prompt
+        )
+        input_ids = []
+        labels = []
+        for _, (part, label) in enumerate(
+            self.prompter.build_prompt(rejected_message_list)
+        ):
+            if not part:
+                continue
+            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+            prev_idx = len(input_ids)
+            input_ids += _input_ids[prev_idx:]
+            if label:
+                labels += input_ids[prev_idx:]
+            else:
+                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+                prompt_len = len(input_ids)
+        # remap the input_ids, attention_mask and labels
+        rejected_input_ids = input_ids
+        rejected_labels = labels
+        # pass the chosen prompt/row to the Prompter to get the formatted prompt
+        chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
+        input_ids = []
+        labels = []
+        for _, (part, label) in enumerate(
+            self.prompter.build_prompt(chosen_message_list)
+        ):
+            if not part:
+                continue
+            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+            prev_idx = len(input_ids)
+            input_ids += _input_ids[prev_idx:]
+            if label:
+                labels += input_ids[prev_idx:]
+            else:
+                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+
+        return {
+            "rejected_input_ids": rejected_input_ids,
+            "rejected_labels": rejected_labels,
+            "rejected_attention_mask": [1] * len(rejected_labels),
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": [1] * len(labels),
+            "prompt_attention_mask": [1] * prompt_len
+            + [0] * (len(labels) - prompt_len),
+        }
+
+
+class ORPOPrompter(Prompter):
+    """Single Turn prompter for ORPO"""
+
+    def __init__(self, chat_template, tokenizer):
+        self.chat_template = chat_template
+        self.tokenizer = tokenizer
+
+    def build_prompt(
+        self,
+        message_list: MessageList,
+    ) -> Generator[Tuple[str, bool], None, None]:
+        conversation = []
+        for message in message_list.messages:
+            conversation.append(message.model_dump())
+            if message.role == "system":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=False,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), False
+            if message.role == "user":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=True,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), False
+            if message.role == "assistant":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=False,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), True
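A subtlety in the strategy above: `ORPOPrompter.build_prompt` re-renders the whole conversation so far on every yield, and `tokenize_prompt` relies on `prev_idx` slicing so that only the newly appended suffix is tokenized into `input_ids`. An illustration of the yields for a system/user/assistant thread under the chatml template; the `S`/`U`/`A` contents are invented:

```python
# (text, label) pairs yielded by ORPOPrompter.build_prompt, assuming chatml:
yields = [
    # system turn: rendered without a generation prompt, label=False
    ("<|im_start|>system\nS<|im_end|>\n", False),
    # user turn: everything so far plus the assistant generation prompt, label=False
    ("<|im_start|>system\nS<|im_end|>\n"
     "<|im_start|>user\nU<|im_end|>\n"
     "<|im_start|>assistant\n", False),
    # assistant turn: the full thread, label=True -> only these new tokens get labels
    ("<|im_start|>system\nS<|im_end|>\n"
     "<|im_start|>user\nU<|im_end|>\n"
     "<|im_start|>assistant\nA<|im_end|>\n", True),
]
# tokenize_prompt encodes each string and keeps only the tokens past prev_idx,
# so prompt tokens get IGNORE_INDEX labels and prompt_len marks the prompt end.
```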
src/axolotl/train.py
CHANGED

@@ -85,7 +85,7 @@ def train(
     model.generation_config.do_sample = True

     model_ref = None
-    if cfg.rl:
+    if cfg.rl and cfg.rl != "orpo":
         if cfg.adapter and not cfg.rl_adapter_ref_model:
             # use built-in trl autounwrap
             LOG.debug("Passing model_ref: None to RL trainer")
src/axolotl/utils/chat_templates.py
CHANGED

@@ -21,7 +21,7 @@ def chat_templates(user_choice: str):
     templates = {
         "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if
+        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
     }
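The restored chatml template is a plain Jinja string, so it can be exercised directly through the standard `tokenizer.apply_chat_template` API. A small sketch; the model choice is illustrative:

```python
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative
text = tok.apply_chat_template(
    [{"role": "user", "content": "hi"}],
    chat_template=chat_templates("chatml"),
    tokenize=False,
    add_generation_prompt=True,
)
# text == "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n"
```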
src/axolotl/utils/config/__init__.py
CHANGED

@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
                    f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
                )
                cfg.datasets[idx].conversation = "chatml"
+            if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
+                LOG.info(
+                    f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
+                )
+                cfg.datasets[idx].chat_template = "chatml"


 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED

@@ -124,6 +124,7 @@ class RLType(str, Enum):
     dpo = "dpo"  # pylint: disable=invalid-name
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
+    orpo = "orpo"  # pylint: disable=invalid-name


 class ChatTemplate(str, Enum):

@@ -431,6 +432,8 @@ class AxolotlInputConfig(
     dataloader_prefetch_factor: Optional[int] = None
     dataloader_drop_last: Optional[bool] = None

+    remove_unused_columns: Optional[bool] = None
+
     push_dataset_to_hub: Optional[str] = None
     hf_use_auth_token: Optional[bool] = None

@@ -515,6 +518,8 @@ class AxolotlInputConfig(

     neftune_noise_alpha: Optional[float] = None

+    orpo_alpha: Optional[float] = None
+
     max_memory: Optional[
         Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
     ] = None
src/axolotl/utils/freeze.py
CHANGED

@@ -3,7 +3,7 @@ module to freeze/unfreeze parameters by name
 """
 import logging
 import re
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union

 from axolotl.utils.distributed import is_main_process

@@ -99,7 +99,7 @@ def _invert_ranges(


 def _merge_ranges(
-    given_ranges: List[Tuple[int, int]], layer_size: int
+    given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
 ) -> List[Tuple[int, int]]:
     """
     Merges overlapping ranges and sorts the given ranges.

@@ -194,7 +194,9 @@ class LayerNamePattern:
         """
         return self.name_regex.match(name) is not None

-    def _parse_pattern(
+    def _parse_pattern(
+        self, pattern: str
+    ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
         """
         Extracts the range pattern from the given pattern.
tests/test_prompt_tokenizers.py
CHANGED

@@ -8,7 +8,8 @@ from pathlib import Path
 from typing import Optional

 import pytest
+from datasets import load_dataset
 from transformers import AddedToken, AutoTokenizer, LlamaTokenizer

 from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
 from axolotl.prompt_strategies.alpaca_w_system import (

@@ -19,12 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import (
     Llama2ChatPrompter,
     LLama2ChatTokenizingStrategy,
 )
+from axolotl.prompt_strategies.orpo.chat_template import load
 from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     ShareGPTPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
+from axolotl.utils.dict import DictDefault

 LOG = logging.getLogger("axolotl")

@@ -446,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why
     )


+class OrpoTokenizationTest(unittest.TestCase):
+    """test case for the ORPO tokenization"""
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+        tokenizer.add_special_tokens(
+            {
+                "eos_token": AddedToken(
+                    "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+                )
+            }
+        )
+        tokenizer.add_tokens(
+            [
+                AddedToken(
+                    "<|im_start|>", rstrip=False, lstrip=False, normalized=False
+                ),
+            ]
+        )
+        self.tokenizer = tokenizer
+        self.dataset = load_dataset(
+            "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
+        ).select([0])
+
+    def test_orpo_integration(self):
+        strat = load(
+            self.tokenizer,
+            DictDefault({"train_on_inputs": False}),
+            DictDefault({"chat_template": "chatml"}),
+        )
+        res = strat.tokenize_prompt(self.dataset[0])
+        assert "rejected_input_ids" in res
+        assert "rejected_labels" in res
+        assert "input_ids" in res
+        assert "labels" in res
+        assert "prompt_attention_mask" in res
+
+        assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
+        assert len(res["input_ids"]) == len(res["labels"])
+        assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
+
+        assert res["rejected_labels"][0] == -100
+        assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
+
+        assert res["labels"][0] == -100
+        assert res["input_ids"][-1] == res["labels"][-1]
+
+        assert res["prompt_attention_mask"][0] == 1
+        assert res["prompt_attention_mask"][-1] == 0
+
+
 if __name__ == "__main__":
     unittest.main()