|
|
# Supervised Fine-tuning Trainer
|
|
Supervised fine-tuning (SFT) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them on your dataset with just a few lines of code.
|
|
|
Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py). |
|
|
|
## Quickstart
|
|
|
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model. |
|
The following code-snippet takes care of all the data pre-processing and training for you: |
|
|
|
```python |
|
from datasets import load_dataset |
|
from trl import SFTTrainer |
|
|
|
dataset = load_dataset("imdb", split="train") |
|
|
|
trainer = SFTTrainer( |
|
"facebook/opt-350m", |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
max_seq_length=512, |
|
) |
|
trainer.train() |
|
``` |
|
Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. |
|
|
|
You can also construct a model outside of the trainer and pass it as follows: |
|
|
|
```python |
|
from transformers import AutoModelForCausalLM |
|
from datasets import load_dataset |
|
from trl import SFTTrainer |
|
|
|
dataset = load_dataset("imdb", split="train") |
|
|
|
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") |
|
|
|
trainer = SFTTrainer( |
|
model, |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
max_seq_length=512, |
|
) |
|
|
|
trainer.train() |
|
``` |
|
|
|
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify them, create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor via the `args` argument.
|
|
|
## Advanced usage

### Train on completions only
|
|
|
You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated completions only. Note that this works only when `packing=False`.
|
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset: |
|
|
|
```python |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from datasets import load_dataset |
|
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM |
|
|
|
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train") |
|
|
|
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") |
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") |
|
|
|
def formatting_prompts_func(example): |
|
output_texts = [] |
|
for i in range(len(example['instruction'])): |
|
text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}" |
|
output_texts.append(text) |
|
return output_texts |
|
|
|
response_template = " ### Answer:" |
|
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) |
|
|
|
trainer = SFTTrainer( |
|
model, |
|
train_dataset=dataset, |
|
formatting_func=formatting_prompts_func, |
|
data_collator=collator, |
|
) |
|
|
|
trainer.train() |
|
``` |
|
|
|
To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset: |
|
|
|
```python |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from datasets import load_dataset |
|
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM |
|
|
|
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train") |
|
|
|
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") |
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") |
|
|
|
instruction_template = "### Human:" |
|
response_template = "### Assistant:" |
|
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False) |
|
|
|
trainer = SFTTrainer( |
|
model, |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
data_collator=collator, |
|
) |
|
|
|
trainer.train() |
|
``` |
|
|
|
Make sure to have a `pad_token_id` that is different from `eos_token_id`; otherwise the model may not properly learn to predict EOS (End of Sentence) tokens during generation.
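As a sketch, one way to enforce this (the helper name and the `[PAD]` token are illustrative, not a TRL API; `tokenizer` is assumed to be a 🤗 tokenizer you have already loaded):

```python
def ensure_distinct_pad_token(tokenizer, model=None):
    """Give the tokenizer a pad token that differs from its EOS token."""
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        if model is not None:
            # The embedding matrix must grow to cover the newly added token.
            model.resize_token_embeddings(len(tokenizer))
    return tokenizer
```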
|
|
|
|
#### Using token_ids directly for `response_template`
|
|
Some tokenizers, such as Llama 2 (`meta-llama/Llama-2-XXb-hf`), tokenize sequences differently depending on whether they have context or not. For example:
|
|
|
```python |
|
from transformers import AutoTokenizer |
|
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") |
|
|
|
def print_tokens_with_ids(txt): |
|
tokens = tokenizer.tokenize(txt, add_special_tokens=False) |
|
token_ids = tokenizer.encode(txt, add_special_tokens=False) |
|
print(list(zip(tokens, token_ids))) |
|
|
|
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?""" |
|
print_tokens_with_ids(prompt) |
|
|
|
response_template = "### Assistant:" |
|
print_tokens_with_ids(response_template) |
|
``` |
|
|
|
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently: |
|
|
|
- Text (with context): `[2277, 29937, 4007, 22137, 29901]` |
|
- `response_template` (without context): `[835, 4007, 22137, 29901]` |
|
|
|
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text: |
|
|
|
``` |
|
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...]) |
|
``` |
|
|
|
|
|
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed, and pass the resulting `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
|
|
|
```python |
|
response_template_with_context = "\n### Assistant:"  # add context ("\n") so the tokenization matches the dataset

response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:]  # drop the leading tokens that encode the added context
|
|
|
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer) |
|
``` |
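To see why slicing the context-aware ids works, here is a minimal sketch (with hypothetical surrounding token ids) of the subsequence search the collator performs:

```python
def find_sublist(haystack, needle):
    """Return the start index of `needle` inside `haystack`, or -1 if absent."""
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1

# Hypothetical ids for a full prompt containing "### Assistant:" in context:
prompt_ids = [1, 22172, 13, 2277, 29937, 4007, 22137, 29901]

# Tokenized without context, the template is never found...
assert find_sublist(prompt_ids, [835, 4007, 22137, 29901]) == -1
# ...while the context-matched ids are located correctly.
assert find_sublist(prompt_ids, [2277, 29937, 4007, 22137, 29901]) == 3
```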
|
|
|
### Format your input prompts
|
|
|
For instruction fine-tuning, it is quite common to have two columns in the dataset: one for the prompt and one for the response.

This allows you to format examples the way [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did, as follows:
|
```bash |
|
Below is an instruction ... |
|
### Instruction:
{prompt}

### Response:
{completion}
|
``` |
|
Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run: |
|
```python |
|
... |
|
def formatting_prompts_func(example): |
|
output_texts = [] |
|
for i in range(len(example['question'])): |
|
text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}" |
|
output_texts.append(text) |
|
return output_texts |
|
|
|
trainer = SFTTrainer( |
|
model, |
|
train_dataset=dataset, |
|
formatting_func=formatting_prompts_func, |
|
) |
|
|
|
trainer.train() |
|
``` |
|
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed texts. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763).
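As a quick sanity check (toy batch, hypothetical field values), the formatting function should map a batched example dict to one string per row:

```python
def formatting_prompts_func(example):
    # `example` is a batched dict: each field maps to a list of values.
    output_texts = []
    for i in range(len(example["question"])):
        output_texts.append(
            f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        )
    return output_texts

batch = {"question": ["2+2?", "Capital of France?"], "answer": ["4", "Paris"]}
texts = formatting_prompts_func(batch)
print(len(texts))  # 2
```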
|
|
|
### Packing dataset (`ConstantLengthDataset`)
|
|
|
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor. |
|
|
|
```python |
|
... |
|
|
|
trainer = SFTTrainer( |
|
"facebook/opt-350m", |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
packing=True |
|
) |
|
|
|
trainer.train() |
|
``` |
|
|
|
Note that if you use a packed dataset and pass `max_steps` in the training arguments, you will probably train your model for more than a few epochs, depending on how you configured the packed dataset and the training protocol. Double-check that you know and understand what you are doing.
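As a rough back-of-the-envelope sketch (all numbers hypothetical), you can estimate how many effective epochs a given `max_steps` corresponds to once examples are packed into constant-length chunks:

```python
num_examples = 20_000          # dataset rows
avg_tokens_per_example = 200   # rough average after tokenization
seq_length = 1024              # packed sequence length
batch_size = 8
max_steps = 1_000

total_tokens = num_examples * avg_tokens_per_example
tokens_per_step = seq_length * batch_size
effective_epochs = max_steps * tokens_per_step / total_tokens
print(f"~{effective_epochs:.1f} epochs")  # ~2.0 epochs with these numbers
```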
|
|
|
#### Customize your prompts using packed dataset
|
|
|
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example: |
|
|
|
```python |
|
def formatting_func(example): |
|
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" |
|
return text |
|
|
|
trainer = SFTTrainer( |
|
"facebook/opt-350m", |
|
train_dataset=dataset, |
|
packing=True, |
|
formatting_func=formatting_func |
|
) |
|
|
|
trainer.train() |
|
``` |
|
You can further customize the [`ConstantLengthDataset`] by passing its arguments directly to the [`SFTTrainer`] constructor. Please refer to that class's signature for more information.
|
|
|
### Control over the pretrained model |
|
|
|
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to |
|
|
|
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
```
|
you can do so by passing `model_init_kwargs`:
|
```python |
|
... |
|
|
|
trainer = SFTTrainer( |
|
"facebook/opt-350m", |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
model_init_kwargs={ |
|
"torch_dtype": torch.bfloat16, |
|
}, |
|
) |
|
|
|
trainer.train() |
|
``` |
|
Note that all keyword arguments of `from_pretrained()` are supported. |
|
|
|
### Training adapters |
|
|
|
We also support a tight integration with the 🤗 PEFT library, so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.
|
|
|
```python |
|
from datasets import load_dataset |
|
from trl import SFTTrainer |
|
from peft import LoraConfig |
|
|
|
dataset = load_dataset("imdb", split="train") |
|
|
|
peft_config = LoraConfig( |
|
r=16, |
|
lora_alpha=32, |
|
lora_dropout=0.05, |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
|
|
|
trainer = SFTTrainer( |
|
"EleutherAI/gpt-neo-125m", |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
peft_config=peft_config |
|
) |
|
|
|
trainer.train() |
|
``` |
|
|
|
Note that when training adapters, we manually add a saving callback so that only the adapters are saved automatically:
|
```python |
|
import os

from transformers import TrainerCallback


class PeftSavingCallback(TrainerCallback):
|
def on_save(self, args, state, control, **kwargs): |
|
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") |
|
kwargs["model"].save_pretrained(checkpoint_path) |
|
|
|
if "pytorch_model.bin" in os.listdir(checkpoint_path): |
|
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) |
|
``` |
|
If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training. |
|
```python |
|
... |
|
|
|
callbacks = [YourCustomCallback(), PeftSavingCallback()] |
|
|
|
trainer = SFTTrainer( |
|
"EleutherAI/gpt-neo-125m", |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
peft_config=peft_config, |
|
callbacks=callbacks |
|
) |
|
|
|
trainer.train() |
|
``` |
|
|
|
You can also continue training your `PeftModel`. To do so, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer, without passing the `peft_config` argument.
|
|
|
### Training adapters with 8-bit base models
|
|
|
To do so, first load your 8-bit model outside the trainer and pass a `PeftConfig` to the trainer. For example:
|
|
|
```python |
|
... |
|
|
|
peft_config = LoraConfig( |
|
r=16, |
|
lora_alpha=32, |
|
lora_dropout=0.05, |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
"EleutherAI/gpt-neo-125m", |
|
load_in_8bit=True, |
|
device_map="auto", |
|
) |
|
|
|
trainer = SFTTrainer( |
|
model, |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
peft_config=peft_config, |
|
) |
|
|
|
trainer.train() |
|
``` |
|
|
|
## Using Flash Attention and Flash Attention 2 |
|
|
|
You can benefit from Flash Attention 1 & 2 with SFTTrainer out of the box, with minimal code changes.

First, to make sure you have all the latest features from transformers, install transformers from source:
|
|
|
```bash |
|
pip install -U git+https://github.com/huggingface/transformers.git |
|
``` |
|
|
|
Note that Flash Attention currently only works on GPU and in half precision (when using adapters, only the base model needs to be loaded in half precision).

Note also that both features are fully compatible with other tools such as quantization.
|
|
|
### Using Flash Attention 1
|
|
|
For Flash Attention 1, you can use the `BetterTransformer` API and force-dispatch it to use the Flash Attention kernel. First, install the latest optimum package:
|
|
|
```bash |
|
pip install -U optimum |
|
``` |
|
|
|
Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager: |
|
|
|
```diff |
|
... |
|
|
|
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): |
|
trainer.train() |
|
``` |
|
|
|
Note that you cannot train your model with Flash Attention 1 on an arbitrary dataset, as `torch.nn.functional.scaled_dot_product_attention` does not support training with padding tokens when the Flash Attention kernels are used. Therefore, you can only use this feature with `packing=True`. If your dataset contains padding tokens, consider switching to the Flash Attention 2 integration.
|
|
|
Below are some numbers showing the speedup and memory efficiency you can get with Flash Attention 1, on a single NVIDIA T4 (16GB).
|
|
|
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step | |
|
|----------------|-------------------|-------------|------------|------------------------| |
|
| x | facebook/opt-350m | 2048 | 8 | ~59.1s | |
|
| | facebook/opt-350m | 2048 | 8 | **OOM** | |
|
| x | facebook/opt-350m | 2048 | 4 | ~30.3s | |
|
| | facebook/opt-350m | 2048 | 4 | ~148.9s | |
|
|
|
### Using Flash Attention 2
|
|
|
To use Flash Attention 2, first install the latest `flash-attn` package: |
|
|
|
```bash |
|
pip install -U flash-attn |
|
``` |
|
|
|
And add `use_flash_attention_2=True` when calling `from_pretrained`: |
|
|
|
```python |
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_id, |
|
load_in_4bit=True, |
|
use_flash_attention_2=True |
|
) |
|
``` |
|
|
|
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device. |
|
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized. |
|
|
|
In contrast to Flash Attention 1, this integration makes it possible to train your model on an arbitrary dataset, including datasets with padding tokens.
|
|
|
|
### Enhance model's performances using NEFTune
|
|
NEFTune is a technique to boost the performance of chat models, introduced in the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) by Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
|
|
|
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune. |
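The core perturbation is simple; a minimal sketch (assuming the scaling rule described in the paper, uniform noise bounded by `alpha / sqrt(seq_len * hidden_dim)`, and not TRL's internal implementation) could look like:

```python
import torch

def neftune_noise(embeddings: torch.Tensor, alpha: float) -> torch.Tensor:
    """Add uniform noise to input embeddings, scaled as described in the NEFTune paper."""
    seq_len, hidden_dim = embeddings.shape[-2], embeddings.shape[-1]
    scale = alpha / (seq_len * hidden_dim) ** 0.5
    noise = torch.zeros_like(embeddings).uniform_(-scale, scale)
    return embeddings + noise

x = torch.randn(1, 16, 32)
noisy = neftune_noise(x, alpha=5)
# The per-element perturbation magnitude is bounded by alpha / sqrt(L * d).
```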
|
|
|
<div style="text-align: center"> |
|
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png"> |
|
</div> |
|
|
|
To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTTrainer` instance. Note that, to avoid any surprising behaviour, NEFTune is disabled after training to restore the original behaviour of the embedding layer.
|
|
|
```python |
|
from datasets import load_dataset |
|
from trl import SFTTrainer |
|
|
|
dataset = load_dataset("imdb", split="train") |
|
|
|
trainer = SFTTrainer( |
|
"facebook/opt-350m", |
|
train_dataset=dataset, |
|
dataset_text_field="text", |
|
max_seq_length=512, |
|
neftune_noise_alpha=5, |
|
) |
|
trainer.train() |
|
``` |
|
|
|
We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench. |
|
|
|
<div style="text-align: center"> |
|
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-neftune-mistral-7b.png"> |
|
</div> |
|
|
|
Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains. |
|
|
## Best practices
|
|
Pay attention to the following best practices when training a model with that trainer: |
|
|
|
- [`SFTTrainer`] always pads the sequences to the `max_seq_length` argument of the [`SFTTrainer`] by default. If none is passed, the trainer retrieves that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
|
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it. |
|
- For more memory-efficient training using adapters, you can load the base model in 8-bit; to do so, simply add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create an 8-bit base model outside the trainer and pass it.
|
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that relate to the `from_pretrained()` method.
|
|
|
|
## SFTTrainer
|
|
[[autodoc]] SFTTrainer |
|
|
|
|
## ConstantLengthDataset
|
|
[[autodoc]] trainer.ConstantLengthDataset |
|
|