How to fine-tune Google Gemma with ChatML and Hugging Face TRL
Last week, Google released Gemma, a new family of state-of-the-art open LLMs. Gemma comes in two sizes: a 7B parameter version for efficient deployment and development on consumer-size GPUs and TPUs, and a 2B version for CPU and on-device applications. Both come in base and instruction-tuned variants.
After the first week it seemed that Gemma is not very friendly to fine-tune using the ChatML format, which has been adopted by the open source community, e.g. for OpenHermes or Dolphin. I created this blog post to show you how to fine-tune Gemma using ChatML and Hugging Face TRL.
This blog post is derived from my How to Fine-Tune LLMs in 2024 with Hugging Face blog post, tailored to fine-tune Gemma 7B. We will use Hugging Face TRL, Transformers & datasets.
- Setup development environment
- Create and prepare the dataset
- Fine-tune LLM using trl and the SFTTrainer
- Test and evaluate the LLM
Note: This blog was created to run on consumer size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs.
1. Setup development environment
Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune, RLHF, and align open LLMs.
If you are using a GPU with Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer, you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. The TL;DR: it accelerates training up to 3x. Learn more at FlashAttention.
Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of MAX_JOBS. On the g5.2xlarge we used 4.
Installing flash attention from source can take quite a bit of time (10-45 minutes).
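If you are unsure whether your GPU qualifies, here is a tiny sanity check (a sketch, assuming a single CUDA device) that reports whether the card is Ampere or newer:

```python
# Sketch: check that a CUDA device is present and whether it is Ampere
# (compute capability 8.0) or newer, which is required for Flash Attention.
import torch

assert torch.cuda.is_available(), "No CUDA device found"
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor} -> Ampere or newer: {major >= 8}")
```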
We will also need to log in to our Hugging Face account to be able to access Gemma. To use Gemma, you first need to agree to the terms of use. You can do this by visiting the Gemma model page and following the gating mechanism.
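A minimal way to log in from Python is shown below; the token placeholder is an assumption you need to replace with your own access token (alternatively, run huggingface-cli login in a terminal).

```python
# Log in to the Hugging Face Hub to be able to download the gated Gemma weights.
from huggingface_hub import login

login(
    token="hf_...",              # replace with your own access token
    add_to_git_credential=True,  # optionally store it as a git credential
)
```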
2. Create and prepare the dataset
We are not going to focus on creating a dataset in this blog post. If you want to learn more about creating a dataset, I recommend reading the How to Fine-Tune LLMs in 2024 with Hugging Face blog post. We are going to use the Databricks Dolly dataset, already formatted as messages. This means we can use the conversational format to fine-tune our model.
The latest release of trl supports conversational dataset formats. This means we don't need to do any additional formatting of the dataset; we can use it as is.
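Loading it could look like the following sketch; the dataset id is an assumption, any Dolly variant already stored in the messages format works the same way.

```python
# Sketch: load a Dolly dataset that is already stored in the conversational
# ("messages") format. The dataset id below is an assumption.
from datasets import load_dataset

dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

print(dataset[0]["messages"])
# [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
```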
3. Fine-tune LLM using trl and the SFTTrainer
We will use the SFTTrainer from trl to fine-tune our model. The SFTTrainer makes it straightforward to run supervised fine-tuning of open LLMs. The SFTTrainer is a subclass of the Trainer from the transformers library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:
- Dataset formatting, including conversational and instruction format
- Training on completions only, ignoring prompts
- Packing datasets for more efficient training
- PEFT (parameter-efficient fine-tuning) support including Q-LoRA
- Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)
We will use the dataset formatting, packing and PEFT features in our example. As PEFT method we will use QLoRA, a technique that reduces the memory footprint of large language models during fine-tuning by using quantization, without sacrificing performance.
Note: Gemma comes with a big vocabulary of ~250,000 tokens. Normally, if you want to fine-tune LLMs on the ChatML format, you would need to add special tokens to the tokenizer and model and teach them to understand the different roles in a conversation. But Google included ~100 placeholder tokens in the vocabulary, which we can replace with special tokens like <|im_start|> and <|im_end|>. I created a tokenizer for the ChatML format, philschmid/gemma-tokenizer-chatml, which you can use to fine-tune Gemma with ChatML.
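Loading the tokenizer together with the quantized base model could look roughly like this; the 4-bit settings and the attn_implementation flag are assumptions you may need to adjust to your hardware.

```python
# Sketch: load the ChatML tokenizer and google/gemma-7b in 4-bit for QLoRA.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained("philschmid/gemma-tokenizer-chatml")
tokenizer.padding_side = "right"  # pad on the right during training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # drop this if Flash Attention is not installed
    quantization_config=bnb_config,
)
```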
The chat template used during fine-tuning is not 100% compatible with the ChatML format, since google/gemma-7b requires inputs to always start with a <bos> token. This means our inputs will look like the following.
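Here is a sketch using the ChatML tokenizer from above; the exact whitespace depends on the chat template.

```python
# Sketch: render one conversation with the ChatML template; Gemma's <bos> token
# is prepended by the template itself.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("philschmid/gemma-tokenizer-chatml")

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
# Roughly:
# <bos><|im_start|>user
# What is the capital of France?<|im_end|>
# <|im_start|>assistant
# The capital of France is Paris.<|im_end|>
```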
Note: We don't know why Gemma needs a <bos> token at the beginning of the input.
The SFTTrainer supports a native integration with peft, which makes it super easy to efficiently tune LLMs using, e.g., QLoRA. We only need to create our LoraConfig and provide it to the trainer.
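A minimal LoraConfig could look like this; the rank, alpha, dropout and target modules are assumptions you may want to tune for your setup.

```python
# Sketch: LoRA configuration for QLoRA fine-tuning; values are assumptions.
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=8,                 # scaling factor for the LoRA updates
    lora_dropout=0.05,            # dropout applied to the LoRA layers
    r=6,                          # rank of the low-rank adapters
    bias="none",                  # do not train bias terms
    target_modules="all-linear",  # apply LoRA to all linear layers
    task_type="CAUSAL_LM",
)
```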
Before we can start our training we need to define the hyperparameters (TrainingArguments) we want to use.
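The values below are assumptions sized for a single 24GB GPU, loosely following the QLoRA paper; treat them as a starting point, not the definitive configuration.

```python
# Sketch: training hyperparameters; all values are assumptions you can adjust.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="gemma-7b-dolly-chatml",  # directory to save checkpoints to
    num_train_epochs=3,                  # number of training epochs
    per_device_train_batch_size=2,       # batch size per GPU
    gradient_accumulation_steps=2,       # accumulate gradients over 2 steps
    gradient_checkpointing=True,         # trade compute for memory
    optim="adamw_torch_fused",           # fused AdamW optimizer
    logging_steps=10,                    # log metrics every 10 steps
    save_strategy="epoch",               # save a checkpoint at the end of each epoch
    bf16=True,                           # train in bfloat16
    learning_rate=2e-4,                  # learning rate, based on the QLoRA paper
    max_grad_norm=0.3,                   # gradient clipping, based on the QLoRA paper
    warmup_ratio=0.03,                   # warmup ratio, based on the QLoRA paper
    lr_scheduler_type="constant",        # keep the learning rate constant
)
```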
We now have every building block we need to create our SFTTrainer and start training our model.
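Assembling the trainer could look roughly like this, reusing model, tokenizer, dataset, peft_config and args from the previous steps; max_seq_length and the dataset_kwargs are assumptions based on the trl version I used.

```python
# Sketch: create the SFTTrainer with packing and the conversational dataset.
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=1512,   # maximum packed sequence length (assumption)
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False,   # the chat template already adds special tokens
        "append_concat_token": False,  # no extra separator token when packing
    },
)
```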
Start training our model by calling the train() method on our Trainer instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapter weights and not the full model.
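The training call itself is a one-liner; saving afterwards writes only the adapter weights, as described above.

```python
# Start training; with PEFT enabled only the LoRA adapter weights are saved.
trainer.train()

# Save the adapter into args.output_dir
trainer.save_model()
```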
The training with Flash Attention for 3 epochs with a dataset of 15k samples took 4:14:36 on a g5.2xlarge. The instance costs $1.21/h, which brings us to a total cost of only ~$5.3.
Optional: Merge LoRA adapter into the original model
When using QLoRA, we only train adapters and not the full model. This means when saving the model during training we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the merge_and_unload method and then save the model with the save_pretrained method.
Check out the How to Fine-Tune LLMs in 2024 with Hugging Face blog post to learn how to do it.
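For reference, a sketch of the merge could look like this; the paths are assumptions pointing at the training output directory.

```python
# Sketch: load the trained adapter, merge it into the base weights and save
# the full model. Paths are assumptions.
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "gemma-7b-dolly-chatml",  # directory containing the adapter weights
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("gemma-7b-dolly-chatml-merged", safe_serialization=True)
```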
4. Test Model and run Inference
After the training is done we want to evaluate and test our model. We will load different samples from the original dataset and evaluate the model on those samples, using a simple loop and accuracy as our metric.
Note: Evaluating generative AI models is not a trivial task, since one input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out the Evaluate LLMs and RAG, a practical example using Langchain and Hugging Face blog post.
We load the adapted model and the tokenizer into the pipeline to easily test it, and extract the token id of <|im_end|> to use it in the generate method.
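A sketch of this step; the model path is an assumption (your training output directory or a Hub repo id).

```python
# Sketch: load the fine-tuned adapter into a text-generation pipeline and look
# up the token id of <|im_end|> to use as the end-of-sequence token.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline

peft_model_id = "gemma-7b-dolly-chatml"  # assumption: your training output directory

model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id, device_map="auto", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
```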
Let's test some prompt samples and see how the model performs.
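For example, a small loop over a few hypothetical prompts; the sampling parameters are assumptions.

```python
# Sketch: generate answers for a few example prompts using the ChatML template.
prompts = [
    "What is the capital of Germany? Explain why.",
    "Write a Python function to calculate the factorial of a number.",
    "What is the difference between a fruit and a vegetable? Give examples.",
]

for prompt in prompts:
    messages = [{"role": "user", "content": prompt}]
    formatted = pipe.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    outputs = pipe(
        formatted,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        eos_token_id=eos_token_id,
    )
    print(f"prompt:\n{prompt}")
    print(f"response:\n{outputs[0]['generated_text'][len(formatted):]}\n")
```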
Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.