RLHF in 2024 with DPO & Hugging Face
This blog post walks you through how to use DPO to improve open LLMs using Hugging Face TRL, Transformers & datasets in 2024.
Research and experiments suggest that DPO should only be applied after SFT. This means we need an already fine-tuned LLM that can then be aligned with DPO. In this example we will use cognitivecomputations/dolphin-2.1-mistral-7b, a fine-tuned Mistral 7B that uses the ChatML template.
- Setup development environment
- Create and prepare the preference dataset
- Align LLM with TRL and the DPOTrainer
- Test LLM (vibe-check)
- Evaluate open LLMs on MT-Bench
Note: This example is designed to be an introduction to DPO and TRL. It is built for a single-GPU environment to guide you through the process. For production use, you should consider a distributed environment. It should be possible to run the example on a single GPU with at least 24GB of memory by reducing the training arguments, e.g. the batch size and max sequence length, and by running evaluation only after the training.
1. Setup development environment
Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a library on top of transformers and datasets that makes it easier to fine-tune, RLHF, and align open LLMs.
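A minimal install sketch (no versions pinned; use whatever releases you have validated):

```bash
# Install PyTorch and TensorBoard, then the Hugging Face libraries used in this example
pip install torch tensorboard

pip install --upgrade transformers datasets accelerate bitsandbytes peft trl
```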
If you are using a GPU with Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer, you can use Flash Attention. Flash Attention can speed up training by up to 3x.
Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of MAX_JOBS.
Installing flash attention can take quite a bit of time (10-45 minutes).
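A hedged install sketch; `MAX_JOBS=4` is an assumption to keep the compilation from exhausting your RAM:

```bash
# flash-attn compiles CUDA kernels; limit parallel build jobs if you have little RAM
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```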
We will use the Hugging Face Hub as remote model storage and automatically push our model, logs and information to the Hub during training. You need a Hugging Face account for this. After you have an account, we will use the login util from the huggingface_hub package to log into our account and store our token (access key) on disk.
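A minimal sketch of the login step:

```python
from huggingface_hub import login

login(
    token="",  # add your Hugging Face access token here
    add_to_git_credential=True,  # store the token on disk for git operations
)
```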
2. Create and prepare the dataset
Improving the helpfulness or quality of LLMs through alignment methods like DPO doesn't come for free. Compared to traditional supervised fine-tuning (SFT), alignment methods require preference data. Preference data is crucial as it serves as a proxy against which the model's outputs are evaluated and aligned. A typical DPO dataset consists of a triplet of prompt, chosen, and rejected response. There are several ways to create such a dataset, including:
- Using existing open-source datasets, e.g., SHP
- Using LLMs to create synthetic preferences, e.g., Ultrafeedback
- Using Humans to create datasets, e.g., HH
- Using a combination of the above methods, e.g., Orca DPO
Each method has advantages and disadvantages and depends on the budget, time, and quality requirements.
It's important to recognize that preference datasets can inherently reflect the biases of the human/AI they are based on. To ensure broader applicability and fairness, it's crucial to incorporate a diverse range of feedback in creating these datasets.
In our example we will use the argilla/ultrafeedback-binarized-preferences-cleaned dataset. The best DPO dataset represents the real-world preferences of your users or customers. If you haven't collected preferences yet, start with your existing SFT data and use LLMs of different sizes/quality to generate feedback. This method was used to create the Orca DPO dataset, where GPT-4 was used for the accepted responses and Llama 70B Chat for the rejected responses. A DPO dataset will have the following format:
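Sketched with placeholders (not real data):

```python
# Schematic DPO sample - placeholders only
{"chosen": "<prompt + good response>", "rejected": "<prompt + worse response>"}
```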
The <prompt + good response> and <prompt + worse response> are represented in the conversational format as:
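For example, sketched as Python dictionaries:

```python
# The same content represented as lists of role/content messages
{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "<good response>"},
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "<worse response>"},
    ],
}
```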
Note: If the dataset includes multiple turns, you need to make sure that only the last turn differs between chosen and rejected. If not, you must reduce the conversation until only the last assistant turn is different.
The DPOTrainer expects the inputs as triples of (prompt, chosen, rejected), where chosen and rejected are the final turn of a dialogue and the prompt is the preceding N-1 turns. Those inputs also need to be formatted with the chat template of the model, e.g. <|im_start|>user\nINSTRUCTION\n<|im_end|>\n<|im_start|>assistant\n....
In our example we are going to load our open-source dataset using the 🤗 Datasets library and then convert it into the correct format. The argilla/ultrafeedback-binarized-preferences-cleaned dataset already comes in the DPO format (chosen/rejected). This means we can create our triplets and template them using a tokenizer and the apply_chat_template method. We randomly downsample the dataset to 11,000 training samples and 2,750 test samples.
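A sketch of this preparation step; it assumes the dataset's chosen/rejected columns are lists of role/content messages ending with the assistant turn, and the output file names are illustrative:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

model_id = "cognitivecomputations/dolphin-2.1-mistral-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the preference dataset
dataset = load_dataset("argilla/ultrafeedback-binarized-preferences-cleaned", split="train")

def create_triplets(example):
    """Split a sample into (prompt, chosen, rejected) and apply the ChatML template."""
    prompt_messages = example["chosen"][:-1]      # everything up to the last assistant turn
    chosen_messages = example["chosen"][-1:]      # last assistant turn of the preferred answer
    rejected_messages = example["rejected"][-1:]  # last assistant turn of the rejected answer
    return {
        "prompt": tokenizer.apply_chat_template(prompt_messages, tokenize=False),
        "chosen": tokenizer.apply_chat_template(chosen_messages, tokenize=False),
        "rejected": tokenizer.apply_chat_template(rejected_messages, tokenize=False),
    }

dataset = dataset.map(create_triplets, remove_columns=list(dataset.features))
# Randomly downsample to 11,000 training and 2,750 test samples
dataset = dataset.train_test_split(train_size=11_000, test_size=2_750, seed=42)

# Save to disk so we can reload the datasets for training
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")
```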
Note: This step can be different for your use case. For example, you might need to create the conversational format yourself and concatenate the prompt and the chosen/rejected response, e.g. Human:\n ... Assistant:\n.
3. Align LLM with TRL and the DPOTrainer
TRL supports DPO through a dedicated DPOTrainer for aligning LLMs from preference data, as described in Direct Preference Optimization: Your Language Model is Secretly a Reward Model. The DPOTrainer is a subclass of the Trainer from the transformers library and supports all the same features, including logging, evaluation, and checkpointing.
One big difference to SFT is that for DPO we need an additional reference model, which is used for the KL-divergence term that helps stabilize the training. The reference model is normally the same model as the one we are training, but frozen. This means DPO needs additional memory and compute resources. To keep our example efficient we will use PEFT and adapters. We load our fine-tuned model and then add new trainable adapters. This means that with DPO we will only tune the adapters and not the whole model, and the original model is then used as the reference model itself. If you want to train all parameters with DPO you need to provide both a model and a reference model, but this requires more memory and compute resources.
Let's start by loading our saved datasets from disk.
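A short sketch, assuming the json files written in the preparation step above:

```python
from datasets import load_dataset

# Load our prepared preference datasets from disk
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
```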
We are going to train cognitivecomputations/dolphin-2.1-mistral-7b. Dolphin is a fine-tuned Mistral 7B that uses the ChatML template and supports system messages. You can easily swap out the model for another model, e.g. Mistral or Mixtral models, TII Falcon, or any other LLM, by changing our model_id variable.
Note: Be aware that the bigger the model, the more memory it will require. In our example we will use the 7B version, which can be tuned on 24GB GPUs. If you have a smaller GPU, you will need to reduce the training arguments (e.g. batch size, sequence lengths) or use a smaller model.
The first step is to load the model in int-4 using bitsandbytes, together with its tokenizer.
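A hedged loading sketch (the 4-bit settings mirror common QLoRA defaults and are assumptions; drop attn_implementation if Flash Attention is not installed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "cognitivecomputations/dolphin-2.1-mistral-7b"

# QLoRA-style 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires flash-attn
    quantization_config=bnb_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Mistral tokenizers usually ship without a pad token
```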
Compared to the SFTTrainer, the DPOTrainer has two parameters related to dataset sizing: max_prompt_length and max_length. The max_prompt_length is the maximum length of the prompt and max_length is the maximum length of the prompt plus the chosen or rejected response. Those are used for tokenization, padding and truncation. If we set them too low our data will be cut off, but if we set them too high we will waste memory and time.
The Alignment Handbook went with a max_prompt_length of 512 and a max_length of 1024, combined with truncation side left (90% of data samples were in that range). Truncation side left means the beginning of a sequence is removed, so we keep the important assistant response. In our example we want to cover roughly the 97th percentile and filter out longer samples, rather than truncating them.
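One hedged way to do that filtering; the thresholds (768/1536) are assumptions standing in for the ~97th-percentile values of your own length distribution:

```python
max_prompt_length = 768   # assumption - pick from your prompt length distribution
max_length = 1536         # assumption - prompt + chosen/rejected response

def within_length(sample):
    # Keep a sample only if neither (prompt + chosen) nor (prompt + rejected) exceeds max_length
    prompt_len = len(tokenizer(sample["prompt"])["input_ids"])
    chosen_len = len(tokenizer(sample["chosen"])["input_ids"])
    rejected_len = len(tokenizer(sample["rejected"])["input_ids"])
    return (
        prompt_len <= max_prompt_length
        and prompt_len + chosen_len <= max_length
        and prompt_len + rejected_len <= max_length
    )

train_dataset = train_dataset.filter(within_length)
eval_dataset = eval_dataset.filter(within_length)
```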
Note: You could reduce the max_seq_length to 1512; this would reduce memory usage and let you increase the batch_size.
The DPOTrainer supports a native integration with peft, which makes it super easy to efficiently align LLMs using, e.g., QLoRA. We only need to create our LoraConfig and provide it to the trainer. Our LoraConfig parameters are the same as for the SFT example.
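A sketch of the LoraConfig; the exact rank/alpha values are assumptions carried over from a typical QLoRA setup:

```python
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",  # apply LoRA to all linear layers
    task_type="CAUSAL_LM",
)
```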
Before we can start our training we need to define the hyperparameters (TrainingArguments) and DPO parameters.
Based on the Alignment Handbook we know that we need to use a ~10-100x smaller learning rate for DPO compared to SFT. In our example we reduce the learning rate from 2e-4 (SFT) to 5e-5 (DPO) or 40x smaller.
Another important parameter is beta, which is used to control the strength of the alignment; it is typically something in the range of 0.1 to 0.5. A higher beta means less divergence from the initial reference model, i.e. the text generations stay very similar in terms of their probability distributions. In terms of training length, we go with 1 epoch, which is a good starting point. There is no rule of thumb for the number of epochs; it is also related to the number of epochs used during fine-tuning.
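A hedged hyperparameter sketch; every value is an assumption to adjust to your hardware, and newer TRL versions move the DPO-specific settings into a DPOConfig instead:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="dolphin-dpo",            # output directory and Hub repository id (assumption)
    num_train_epochs=1,                  # 1 epoch as discussed above
    per_device_train_batch_size=4,       # reduce if you run out of GPU memory
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,         # trade compute for memory
    optim="adamw_torch_fused",
    learning_rate=5e-5,                  # ~40x smaller than for SFT
    max_grad_norm=0.3,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=25,
    evaluation_strategy="steps",
    eval_steps=700,
    save_steps=500,
    bf16=True,
    push_to_hub=True,                    # push model and logs to the Hub
    report_to="tensorboard",
)
```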
We now have every building block we need to create our DPOTrainer and start training our model.
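A sketch of the trainer setup; with an older TRL release (where beta and the length limits are trainer arguments) it looks roughly like this, while newer versions expect them in a DPOConfig:

```python
from trl import DPOTrainer

trainer = DPOTrainer(
    model,
    ref_model=None,            # with a peft_config the frozen base model is reused as reference
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,                  # DPO beta, see discussion above
    max_prompt_length=768,     # assumption, matching the filtering step
    max_length=1536,           # assumption, matching the filtering step
)
```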
Start training our model by calling the train() method on our Trainer instance. This will start the training loop and train our model for 1 epoch. Since we are using a PEFT method, we will only save the adapter model weights and not the full model.
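For example:

```python
# Start the training loop; with PEFT only the adapter weights are saved
trainer.train()
trainer.save_model()
```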
Note: During the training we want to minimize the loss and grow the reward/margins metrics. Keep an eye on the reward/margins metrics; if they are not growing you might need to increase the beta parameter or adjust the learning_rate.
The training with Flash Attention for 1 epoch with a dataset of ~10k samples took ~01:30:00 on 1x H100 GPU. You should be able to run the training on a g5.2xlarge instance by reducing the batch_size (est. to 1) and maybe the max_seq_length (est. to 1512).
Optional: Merge LoRA adapter into the original model
When using QLoRA, we only train adapters and not the full model. This means that when saving the model during training we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the merge_and_unload method and then save the model with the save_pretrained method. This will save a default model, which can be used for inference.
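A hedged merging sketch, assuming the adapter was saved to the output directory used above ("dolphin-dpo"):

```python
import torch
from peft import AutoPeftModelForCausalLM

# Load the trained adapter together with its base model
model = AutoPeftModelForCausalLM.from_pretrained(
    "dolphin-dpo",               # path/repo of the trained adapter (assumption)
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Merge the LoRA weights into the base weights and save a standalone model
merged_model = model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
```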
Note: You might require > 30GB CPU Memory.
4. Test LLM (vibe-check)
After the training is done we want to test and evaluate our model. Evaluating generative AI models in an open-ended way is not trivial, since one input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out the Evaluate LLMs and RAG, a practical example using Langchain and Hugging Face blog post. Especially when using RLHF techniques like DPO, it's important to "vibe-check" the model.
This means we want to manually check if the responses are more aligned with what our users or customers want. This could mean checking whether the responses are more helpful, more accurate, more engaging, or more informative than before. A good test here is, if you have data from your SFT or previous LLMs, to compare the outputs and see if the new model is better.
In our case we just check a few examples and see if the model generates helpful responses using unseen prompts.
We randomly select prompts from the teknium/OpenHermes-2.5 dataset and a Hugging Face special.
Let's iterate over the prompts and generate a response using the generate method.
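A minimal generation sketch; the example prompts and the "merged_model" path are illustrative assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("merged_model", device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.1-mistral-7b")

prompts = [
    "Why can camels survive for long without water?",
    "What is the difference between a fruit and a vegetable?",
]

for prompt in prompts:
    messages = [{"role": "user", "content": prompt}]
    # Render the ChatML prompt and generate a response
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(input_ids, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9)
    print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))
```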
5. Evaluate open LLMs on MT-Bench
For our use case we will use MT-Bench. MT-Bench is a benchmark designed by LMSYS to test the conversation and instruction-following capabilities of large language models (LLMs). It evaluates LLMs through multi-turn conversations, focusing on their ability to engage in coherent, informative, and engaging exchanges. Since human evaluation is very expensive and time consuming, LMSYS uses GPT-4-Turbo to grade the model responses. Their paper shows an over 80% agreement between strong LLM judges and human preferences. The LMSYS leaderboard is updated regularly (last updated February 2, 2024). MT-Bench is part of the FastChat repository.
MT-Bench supports two different evaluation strategies:
- single-answer grading: an LLM grades the model's answer directly and gives it a score on a scale of 1-10
- pair-wise comparison: compare two models and see which one is better using an LLM as judge, resulting in a win rate.
We are going to use the pair-wise comparison method to compare the base SFT Model with the DPO model, to see if aligning the model with DPO improved the model. Running pairwise comparison on MT-Bench includes the following steps:
- Clone the FastChat Repository & install the requirements
- Generate Responses using our SFT (original) & DPO (trained) model
- Evaluate the responses using pair-wise comparison and GPT-4-Turbo as Judge
- Plot and compare the results
MT-Bench currently only supports OpenAI or Anthropic models as judges, with GPT-4 being the best. If you don't have access to GPT-4 you need to use a different evaluation method. I forked the FastChat repository and added GPT-4 Turbo reference answers to keep the cost lower.
Note: If you use this example to train a different model, e.g. Llama, you need to make sure that your model is registered and supported in FastChat. This means you need:
- a registered conversation template
- a model adapter used to match the model path
- to register the model adapter
The easiest way to do this is to fork my repository and then add your model. In our example the base model is cognitivecomputations/dolphin-2.1-mistral-7b, which is already registered in FastChat.
1. Clone the FastChat Repository & install the requirements
Let's start by cloning the FastChat repository and installing the requirements.
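A sketch of the setup; swap the URL for the fork mentioned above if you want the GPT-4 Turbo reference answers (the extras names follow the FastChat pyproject and may change between versions):

```bash
git clone https://github.com/lm-sys/FastChat.git
pip install -e "./FastChat[model_worker,llm_judge]"
```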
Note: Restart your notebook/kernel to clear up all GPU memory.
2. Generate Responses using our SFT (original) & DPO (trained) model
To generate the responses for MT-Bench we need to change our directory into FastChat/fastchat/llm_judge and then run the gen_model_answer.py script. This will generate the responses and save them into a file. We will use the default --max-new-token length of 1024, which could lead to some truncation. If you want to avoid truncation you can increase the --max-new-token length to 1512 or higher.
We change into the FastChat/fastchat/llm_judge directory to run all the evaluation scripts.
Let's start with the SFT model and then the DPO model.
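A hedged command sketch; the model-id "mistral-dolphin-sft" is just a label that we reuse later for judging:

```bash
cd FastChat/fastchat/llm_judge

# Generate MT-Bench answers with the original SFT model
python gen_model_answer.py \
    --model-path cognitivecomputations/dolphin-2.1-mistral-7b \
    --model-id mistral-dolphin-sft
```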
Note: The answers of the models will be stored in FastChat/fastchat/llm_judge/data/mt_bench/model_answer. You might want to save them for additional evaluation later, when you have a new fine-tuned model.
Note: Generating all responses can take a while, ~60 minutes or more.
Now, we generate the responses using the DPO model.
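Same command, pointed at the trained model; the relative path to the merged model is an assumption about where you saved it:

```bash
# Generate MT-Bench answers with the DPO-aligned model
python gen_model_answer.py \
    --model-path ../../../merged_model \
    --model-id mistral-dolphin-dpo
```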
Note: Generating all responses can take a while, ~120 minutes or more.
3. Evaluate the responses using pair-wise comparison and GPT-4-Turbo as Judge
After we have the responses we can evaluate them using the gen_judgment.py script. This will pairwise compare all the responses using GPT-4-Turbo and rate which response is better.
Note: We need an OPENAI_API_KEY with access to GPT-4 Turbo; running MT-Bench will cost ~$1-2 per model evaluation.
Note: This can take ~70 minutes.
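A hedged sketch of the judging step (flag names and the judge model id may differ between FastChat versions):

```bash
export OPENAI_API_KEY=sk-...   # your key with GPT-4 Turbo access

python gen_judgment.py \
    --mode pairwise-all \
    --model-list mistral-dolphin-sft mistral-dolphin-dpo \
    --judge-model gpt-4-1106-preview \
    --parallel 2
```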
4. Plot and compare the results
After we have the results we can plot them and compare the win-rate of the SFT and DPO model.
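The aggregated win/loss/tie counts can be printed with the show_result.py helper, roughly like this (producing a table similar to the one below):

```bash
python show_result.py --mode pairwise-all --model-list mistral-dolphin-sft mistral-dolphin-dpo
```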
| model | win | loss | tie | win_rate | loss_rate | win_rate_adjusted |
| --- | --- | --- | --- | --- | --- | --- |
| mistral-dolphin-dpo | 45 | 17 | 98 | 0.28125 | 0.10625 | 0.5875 |
| mistral-dolphin-sft | 17 | 45 | 98 | 0.10625 | 0.28125 | 0.4125 |
By using DPO we were able to achieve a win-rate of 0.5875 compared to 0.4125 with the SFT model. This means by applying DPO we tuned our model to generate responses, which are more aligned with what humans/AI would prefer. This is not optimal yet, but it's a good start.
Since this guide is only a starting point, you should consider additional evaluation methods, e.g. human evaluation or benchmarks for instruction-following capabilities. We also might not have reached the full potential of the model; consider training for more epochs and on a larger dataset to improve it further.
5. Clean up the FastChat Repository
Since we temporarily cloned the FastChat repository, we can now clean up by deleting the directory.
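For example (assuming we are still inside the llm_judge directory):

```bash
cd ../../..
rm -rf FastChat
```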
Note: If you want to keep your evaluation results, you should save the model_answer and judgment directories first.
Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.