How to Fine-Tune Multimodal Models or VLMs with Hugging Face TRL
Multimodal LLMs have made tremendous progress recently. We now have a diverse ecosystem of powerful open multimodal models, mostly Vision-Language Models (VLMs), including Meta AI's Llama-3.2-11B-Vision, Mistral AI's Pixtral-12B, Qwen's Qwen2-VL-7B, and Allen AI's Molmo-7B-D-0924.
These VLMs can handle a variety of multimodal tasks, including image captioning, visual question answering, and image-text matching without additional training. However, to customize a model for your specific application, you may need to fine-tune it on your data to achieve higher quality results or to create a more efficient model for your use case.
This blog post walks you through how to fine-tune open VLMs using Hugging Face TRL, Transformers & datasets in 2024. We'll cover:
- Define our multimodal use case
- Setup development environment
- Create and prepare the multimodal dataset
- Fine-tune VLM using `trl` and the `SFTTrainer`
- Test and evaluate the VLM
Note: This blog was created to run on consumer-size GPUs (24GB), e.g. an NVIDIA A10G or RTX 4090/3090, but can easily be adapted to run on bigger GPUs.
1. Define our multimodal use case
When fine-tuning VLMs, it's crucial to clearly define your use case and the multimodal task you want to solve. This will guide your choice of base model and help you create an appropriate dataset for fine-tuning. If you haven't defined your use case yet, you might want to revisit your requirements.
It's worth noting that for most use cases fine-tuning might not be the first option. We recommend evaluating pre-trained models or API-based solutions before committing to fine-tuning your own model.
As an example, we'll use the following multimodal use case:
We want to fine-tune a model that can generate detailed product descriptions based on product images and basic metadata. This model will be integrated into our e-commerce platform to help sellers create more compelling listings. The goal is to reduce the time it takes to create product descriptions and improve their quality and consistency.
Existing models might already be very good for this use case, but you might want to tweak or tune them to your specific needs. This image-to-text generation task is well-suited for fine-tuning VLMs, as it requires understanding visual features and combining them with textual information to produce coherent and relevant descriptions. I created a test dataset for this use case using Gemini 1.5: philschmid/amazon-product-descriptions-vlm.
2. Setup development environment
Our first step is to install the Hugging Face libraries and PyTorch, including `trl`, `transformers`, and `datasets`. If you haven't heard of `trl` yet, don't worry. It is a library on top of `transformers` and `datasets` that makes it easier to fine-tune, align, and apply RLHF to open LLMs.
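A minimal install sketch is shown below; package versions are left unpinned here (pin them for reproducibility), and `qwen-vl-utils` is included because it is used later for the Qwen2-VL image preprocessing.

```bash
# Install PyTorch and supporting libraries
pip install torch tensorboard pillow

# Install the Hugging Face libraries (pin versions for reproducibility)
pip install --upgrade transformers datasets accelerate bitsandbytes trl peft qwen-vl-utils
```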
We will use the Hugging Face Hub as a remote model versioning service. This means we will automatically push our model, logs, and information to the Hub during training. You must register on Hugging Face for this. After you have an account, we will use the `login` util from the `huggingface_hub` package to log into our account and store our token (access key) on disk.
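A minimal sketch of that login step, assuming you paste in your own access token:

```python
from huggingface_hub import login

login(
    token="",                    # add your Hugging Face access token here
    add_to_git_credential=True,  # also store the token as a git credential
)
```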
3. Create and prepare the dataset
Once you have determined that fine-tuning is the right solution, we need to create a dataset to fine-tune our model. We have to prepare the dataset in a format that the model can understand.
In our example we will use philschmid/amazon-product-descriptions-vlm, which contains 1,350 Amazon products with titles, images, descriptions, and metadata. We want to fine-tune our model to generate product descriptions based on the images, title, and metadata. Therefore we need to create a prompt including the title, metadata, and image; the completion is the description.
TRL supports popular instruction and conversation dataset formats. This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest.
In our example we are going to load our dataset using the Datasets library, apply our prompt, and convert it into the conversational format.
Let's start by defining our instruction prompt.
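The snippet below is a sketch of such a prompt; the exact wording is an assumption and should be adapted to your needs. The `system_message` and `user_prompt` names are reused in the later snippets.

```python
# System instruction for the model (wording is illustrative, not the exact prompt from the post)
system_message = "You are an expert product description writer for Amazon."

# User prompt template; the placeholders are filled per sample with the product metadata
user_prompt = """Create a short, compelling product description for the product shown in the image.

Product title: {product}
Category: {category}"""
```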
Now, we can format our dataset.
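A sketch of the conversion into the conversational ("messages") format TRL understands; the column names (`Product Name`, `Category`, `description`, `image`) are assumptions based on the dataset card and may need adjusting.

```python
from datasets import load_dataset

def format_data(sample):
    # Convert one dataset row into the conversational "messages" format
    return {
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": system_message}]},
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(product=sample["Product Name"], category=sample["Category"]),
                    },
                    {"type": "image", "image": sample["image"]},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": sample["description"]}]},
        ],
    }

# Load the dataset from the Hugging Face Hub and apply the formatting to every sample
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train")
dataset = [format_data(sample) for sample in dataset]
```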
4. Fine-tune VLM using `trl` and the `SFTTrainer`
We are now ready to fine-tune our model. We will use the `SFTTrainer` from `trl` to fine-tune our model. The `SFTTrainer` makes it straightforward to supervise fine-tuning of open LLMs and VLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features.
We will use the PEFT features in our example. As the PEFT method we will use QLoRA, a technique that reduces the memory footprint of large language models during fine-tuning, without sacrificing performance, by using quantization. If you want to learn more about QLoRA and how it works, check out the Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA blog post.
Note: We cannot use Flash Attention as we need to pad our multimodal inputs.
We are going to use the Qwen 2 VL 7B model, but we can easily swap it out for another model, including Meta AI's Llama-3.2-11B-Vision, Mistral AI's Pixtral-12B, or any other VLM, by changing our `model_id` variable. We will use bitsandbytes to quantize our model to 4-bit.
Note: Be aware that the bigger the model, the more memory it will require. In our example we will use the 7B version, which can be tuned on 24GB GPUs.
Correctly preparing the LLM, Tokenizer, and Processor for training VLMs is crucial. The Processor is responsible for including the special tokens and image features in the input.
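A sketch of loading the model in 4-bit together with its processor; the `model_id` points to the Qwen2-VL-7B-Instruct checkpoint, and the dtype and device-map choices are assumptions.

```python
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig

model_id = "Qwen/Qwen2-VL-7B-Instruct"

# QLoRA-style 4-bit (NF4) quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized model and the processor (handles text tokenization and image preprocessing)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
processor = AutoProcessor.from_pretrained(model_id)
```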
The `SFTTrainer` supports a native integration with `peft`, which makes it super easy to efficiently tune LLMs using, e.g., QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. Our `LoraConfig` parameters are defined based on the QLoRA paper and Sebastian Raschka's blog post.
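A sketch of such a `LoraConfig`; the rank, alpha, and target modules below are reasonable defaults, not prescriptions.

```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=8,                                  # rank of the low-rank adapter matrices
    lora_alpha=16,                        # scaling factor for the adapter updates
    lora_dropout=0.05,                    # dropout applied inside the adapter layers
    bias="none",
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    task_type="CAUSAL_LM",
)
```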
Before we can start our training we need to define the hyperparameters (`SFTConfig`) we want to use and make sure our inputs are correctly provided to the model. In contrast to text-only supervised fine-tuning, we need to provide the image to the model as well. Therefore we create a custom data collator which formats the inputs correctly and includes the image features. We use the `process_vision_info` method from a utility package the Qwen2 team provides. If you are using another model, e.g. Llama 3.2 Vision, you might have to check whether it produces the same processed image information.
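The sketch below shows one possible `SFTConfig` and collate function; the hyperparameter values are assumptions, argument names can differ between `trl` versions, and the image-token masking should be verified for your model.

```python
from trl import SFTConfig
from qwen_vl_utils import process_vision_info

training_args = SFTConfig(
    output_dir="qwen2-7b-instruct-amazon-description",  # local dir and Hub repo id (assumption)
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=2e-4,
    bf16=True,
    logging_steps=5,
    save_strategy="epoch",
    push_to_hub=True,
    report_to="tensorboard",
    remove_unused_columns=False,                         # keep the image column for the collator
    dataset_kwargs={"skip_prepare_dataset": True},       # the collator does all the preparation
)

def collate_fn(examples):
    # Render the chat template to text and extract the image inputs for each sample
    texts = [processor.apply_chat_template(ex["messages"], tokenize=False) for ex in examples]
    image_inputs = [process_vision_info(ex["messages"])[0] for ex in examples]

    # Tokenize the texts and preprocess the images into one padded batch
    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)

    # Labels are the input ids with padding and image placeholder tokens masked out
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels
    return batch
```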
We now have every building block we need to create our `SFTTrainer` and start training our model.
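Creating the trainer could then look like this; note that newer `trl` releases rename the `tokenizer` argument to `processing_class`.

```python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    tokenizer=processor.tokenizer,  # newer trl versions: processing_class=processor.tokenizer
)
```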
Start training our model by calling the `train()` method on our `Trainer` instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapted model weights and not the full model.
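A sketch of kicking off the run and saving the adapter weights:

```python
# Start the training loop; with push_to_hub=True checkpoints are also uploaded to the Hub
trainer.train()

# Save the final LoRA adapter weights (not the full model)
trainer.save_model(training_args.output_dir)
```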
Training for 3 epochs with a dataset of ~1k samples took 01:31:58 on a `g6.2xlarge`. The instance costs $0.9776/h, which brings us to a total cost of only $1.4.
5. Test Model and run Inference
After the training is done, we want to evaluate and test our model. First we will load the base model and let it generate a description for a random Amazon product. Then we will load our QLoRA-adapted model and let it generate a description for the same product.
Finally we can merge the adapter into the base model to make it more efficient and run inference on the same product again.
I selected a random product from Amazon and prepared a `generate_description` function to generate a description for the product.
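The helper below is a sketch of such a function, assuming a hand-crafted `sample` dict with `product_name`, `category`, and `image` keys and the same `system_message`/`user_prompt` as during training; the generation settings are illustrative.

```python
from qwen_vl_utils import process_vision_info

def generate_description(sample, model, processor):
    # Build the conversation: system instruction, product image, and metadata prompt
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": sample["image"]},
                {"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
            ],
        },
    ]
    # Render the prompt and preprocess text + image
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.device)

    # Generate and strip the prompt tokens from the output before decoding
    generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.8, top_p=1.0)
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
```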
Awesome, it is working! Let's load our adapter and compare it with the base model.
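Loading the adapter on top of the base model could look like this, assuming the adapter was saved to the training output directory:

```python
from peft import PeftModel

adapter_path = "./qwen2-7b-instruct-amazon-description"  # output_dir from training (assumption)

# Attach the trained LoRA adapter to the quantized base model
model = PeftModel.from_pretrained(model, adapter_path)
```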
Let's compare them side by side using a markdown table.
| Base Generation | Fine-tuned Generation |
| --- | --- |
| Introducing the Hasbro Marvel Avengers Series Marvel Assemble Titan-Held Iron Man Action Figure, a 30.5 cm tall action figure that is sure to bring the excitement of the Marvel Universe to life! This highly detailed Iron Man figure is perfect for fans of all ages and makes a great addition to any toy collection. With its sleek red and gold armor, this Iron Man figure is ready to take on any challenge. The Titan-Held feature allows for a more realistic and dynamic pose, making it a must-have for any Marvel fan. Whether you're a collector or just looking for a fun toy to play with, this Iron Man action figure is the perfect choice. | Unleash the power of Iron Man with this 30.5 cm Hasbro Marvel Avengers Titan Hero Action Figure! This highly detailed Iron Man figure is perfect for collectors and kids alike. Features a realistic design and articulated joints for dynamic poses. A must-have for any Marvel fan's collection! |
Nice! Even though we only had ~1k samples, we can see that the fine-tuning improved the product description generation. The description is much shorter and more concise, which fits our training data.
Optional: Merge the LoRA adapter into the original model
When using QLoRA, we only train adapters and not the full model. This means that when saving the model during training, we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This will save a default model, which can be used for inference.
Note: This requires > 30GB CPU Memory.
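A sketch of the merge, assuming the adapter sits in the training output directory; the base model is loaded in bf16 on CPU, which is where the memory requirement comes from.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "Qwen/Qwen2-VL-7B-Instruct"
adapter_path = "./qwen2-7b-instruct-amazon-description"  # output_dir from training (assumption)

# Load the base model in full bf16 precision on CPU
base_model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)

# Load the LoRA adapter and merge its weights into the base model
model = PeftModel.from_pretrained(base_model, adapter_path)
merged_model = model.merge_and_unload()

# Save the merged model together with the processor so it can be served directly
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
processor = AutoProcessor.from_pretrained(model_id)
processor.save_pretrained("merged_model")
```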
Bonus: Use TRL example script
TRL provides a simple example script to fine-tune multimodal models. You can find the script here. The script can be run directly from the command line and supports all the features of the `SFTTrainer`.
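For reference, an invocation could look roughly like the sketch below; the script path, dataset, and flags are assumptions, so check the script's `--help` for the exact arguments of your `trl` version.

```bash
# Illustrative invocation of TRL's VLM SFT example script (paths/flags may differ per version)
accelerate launch examples/scripts/sft_vlm.py \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 8 \
    --output_dir sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing
```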