How to fine-tune Google Gemma with ChatML and Hugging Face TRL

March 1, 2024, 10 minute read

Last week, Google released Gemma, a new family of state-of-the-art open LLMs. Gemma comes in two sizes: 7B parameters, for efficient deployment and development on consumer-size GPUs and TPUs, and 2B parameters, for CPU and on-device applications. Both come in base and instruction-tuned variants.

After the first week, it seemed that Gemma is not very friendly to fine-tune using the ChatML format, which has been adopted by the open-source community, e.g. OpenHermes or Dolphin. I created this blog post to show you how to fine-tune Gemma using ChatML and Hugging Face TRL.

This blog post is derived from my How to Fine-Tune LLMs in 2024 with Hugging Face blog post, tailored to fine-tuning Gemma 7B. We will use Hugging Face TRL, Transformers & datasets.

  1. Setup development environment
  2. Create and prepare the dataset
  3. Fine-tune LLM using trl and the SFTTrainer
  4. Test and evaluate the LLM

Note: This blog was created to run on consumer size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs.

1. Setup development environment

Our first step is to install Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune, apply RLHF to, and align open LLMs.

# Install PyTorch & other libraries
!pip install "torch==2.1.2" tensorboard
# Install Hugging Face libraries (peft is needed for the LoRA/QLoRA code below)
!pip install --upgrade \
  "transformers==4.38.2" \
  "datasets==2.16.1" \
  "accelerate==0.26.1" \
  "evaluate==0.4.1" \
  "bitsandbytes==0.42.0" \
  "trl==0.7.11" \
  "peft==0.8.2"

If you are using a GPU with Ampere architecture (e.g. NVIDIA A10G or RTX 3090) or newer (e.g. RTX 4090), you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. TL;DR: it accelerates training up to 3x. Learn more at FlashAttention.
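To get a feel for the quadratic term that Flash Attention avoids, the sketch below computes the size of the full attention-score matrix that standard attention materializes for each layer. The head count (16), batch size 1, and 2-byte bf16 elements are illustrative assumptions, not values read from Gemma's config.

```python
def attention_scores_bytes(seq_len: int, num_heads: int, batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes needed for a (batch, heads, seq_len, seq_len) attention-score matrix."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


# Assumed: 16 heads, bf16 (2 bytes per element), batch size 1.
for seq_len in (512, 1512, 4096):
    mib = attention_scores_bytes(seq_len, num_heads=16) / 1024**2
    print(f"seq_len={seq_len:>5}: ~{mib:,.1f} MiB per layer")
```

Doubling the sequence length quadruples this number. Flash Attention processes the scores in tiles held in fast on-chip memory, so this full matrix never has to be written to GPU memory.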

Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of MAX_JOBS. On the g5.2xlarge we used 4.

import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'
# install flash-attn
!pip install ninja packaging
!MAX_JOBS=4 pip install flash-attn --no-build-isolation --upgrade

Installing flash attention from source can take quite a bit of time (10-45 minutes).

We will also need to log in to our Hugging Face account to be able to access Gemma. Before you can use Gemma, you need to agree to its terms of use by visiting the Gemma model page on the Hugging Face Hub and accepting the gated-access terms.

from huggingface_hub import login
login(
  token="", # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)

2. Create and prepare the dataset

We are not going to focus on creating a dataset in this blog post. If you want to learn more about creating a dataset, I recommend reading the How to Fine-Tune LLMs in 2024 with Hugging Face blog post. We are going to use the Databricks Dolly dataset, already formatted as messages. This means we can use the conversational format to fine-tune our model.

{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

The latest release of trl supports conversational dataset formats. This means we don't need to do any additional formatting of the dataset; we can use it as is.
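Because the trainer consumes the messages format as is, it can be worth sanity-checking records before training. The helper below is a hypothetical, minimal check (not part of trl or datasets); it assumes the system/user/assistant roles shown above and that each example ends with an assistant turn.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}


def is_valid_conversation(record: dict) -> bool:
    """Hypothetical check for the {"messages": [{"role", "content"}, ...]} layout."""
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    for message in messages:
        if not isinstance(message, dict):
            return False
        if message.get("role") not in ALLOWED_ROLES or not isinstance(message.get("content"), str):
            return False
    # The example should end with the assistant turn the model is trained to produce.
    return messages[-1]["role"] == "assistant"


example = {"messages": [
    {"role": "system", "content": "You are..."},
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."},
]}
print(is_valid_conversation(example))  # True
```

On the loaded dataset, `all(is_valid_conversation(row) for row in dataset)` would scan every row, since iterating a datasets `Dataset` yields dictionaries.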

from datasets import load_dataset
# Load Dolly Dataset.
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

3. Fine-tune LLM using trl and the SFTTrainer

We will use the SFTTrainer from trl to fine-tune our model. The SFTTrainer makes it straightforward to run supervised fine-tuning of open LLMs. The SFTTrainer is a subclass of the Trainer from the transformers library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:

  • Dataset formatting, including conversational and instruction format
  • Training on completions only, ignoring prompts
  • Packing datasets for more efficient training
  • PEFT (parameter-efficient fine-tuning) support including Q-LoRA
  • Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)
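Training on completions only usually means masking the prompt tokens out of the loss. A minimal sketch of the idea, using the label value -100 that PyTorch's cross-entropy loss ignores by default (trl provides `DataCollatorForCompletionOnlyLM` for this; the example in this post does not use it). The token ids are made up.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_completion_only_example(prompt_ids: list[int], completion_ids: list[int]) -> tuple[list[int], list[int]]:
    """Return input_ids and labels where prompt positions are excluded from the loss."""
    input_ids = prompt_ids + completion_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


input_ids, labels = build_completion_only_example([1, 5, 9], [42, 7, 2])
print(input_ids)  # [1, 5, 9, 42, 7, 2]
print(labels)     # [-100, -100, -100, 42, 7, 2]
```

The labels stay aligned with input_ids; causal LM models in transformers shift them internally when computing the next-token loss.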

We will use the dataset formatting, packing and PEFT features in our example. As our PEFT method we will use QLoRA, a technique that reduces the memory footprint of large language models during fine-tuning by training LoRA adapters on top of a 4-bit quantized base model, without sacrificing performance.
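Packing concatenates several tokenized examples and slices the stream into fixed-length blocks, so little compute is spent on padding. The function below is a simplified, hypothetical sketch of that idea (not trl's `ConstantLengthDataset`): no separator token is inserted between examples, and the trailing partial block is dropped.

```python
from itertools import chain


def pack_sequences(tokenized: list[list[int]], seq_length: int) -> list[list[int]]:
    """Concatenate token lists and cut the stream into blocks of exactly seq_length."""
    stream = list(chain.from_iterable(tokenized))
    return [
        stream[start : start + seq_length]
        for start in range(0, len(stream) - seq_length + 1, seq_length)
    ]


chunks = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_length=4)
print(chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Note that a block can span the boundary between two examples (here `[1, 2, 3, 4]`), which is the trade-off packing makes for throughput.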

Note: Gemma comes with a large vocabulary of ~250,000 tokens. Normally, if you want to fine-tune LLMs on the ChatML format, you would need to add special tokens to the tokenizer and model and teach the model to understand the different roles in a conversation. But Google included ~100 placeholder tokens in the vocabulary, which we can replace with special tokens like <|im_start|> and <|im_end|>. I created a tokenizer for the ChatML format, philschmid/gemma-tokenizer-chatml, which you can use to fine-tune Gemma with ChatML.

The chat template used during fine-tuning is not 100% compatible with the ChatML format, because google/gemma-7b requires inputs to always start with a <bos> token. This means our inputs will look like this:

<bos><|im_start|>system
You are Gemma.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>\n<eos>

Note: We don't know why Gemma needs the <bos> token at the beginning of the input.
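To make the layout concrete, here is a plain-Python sketch that renders a conversation into the string format described above. `render_gemma_chatml` is a hypothetical helper; the real template ships with the philschmid/gemma-tokenizer-chatml tokenizer, and in practice you would call `tokenizer.apply_chat_template` instead. The `add_generation_prompt` branch is an assumption about how an inference prompt ends, following the usual ChatML convention.

```python
def render_gemma_chatml(messages: list[dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Hypothetical sketch of the <bos> + ChatML layout used for fine-tuning Gemma."""
    text = "<bos>"  # google/gemma-7b expects inputs to start with <bos>
    for message in messages:
        text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Inference: open an assistant turn for the model to complete.
        text += "<|im_start|>assistant\n"
    else:
        # Training: close the sequence with <eos>.
        text += "<eos>"
    return text


conversation = [
    {"role": "system", "content": "You are Gemma."},
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
]
print(render_gemma_chatml(conversation))
```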

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Hugging Face model id
model_id = "google/gemma-7b"
tokenizer_id = "philschmid/gemma-tokenizer-chatml"
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",  # requires flash-attn from step 1
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = 'right' # to prevent warnings

The SFTTrainer supports a native integration with peft, which makes it super easy to efficiently tune LLMs using, e.g., QLoRA. We only need to create our LoraConfig and provide it to the trainer.
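LoRA keeps each adapted weight matrix W (d_out × d_in) frozen and trains two small matrices, A (r × d_in) and B (d_out × r). The arithmetic below shows how few parameters that is for a single layer; the 3072 × 3072 projection and rank 6 are illustrative assumptions, not values taken from Gemma's config.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * d_in + d_out * r


# Illustrative dimensions and rank (assumptions, not read from Gemma's config).
d_in = d_out = 3072
r = 6
full = d_in * d_out
lora = lora_trainable_params(d_in, d_out, r)
print(f"frozen: {full:,}  trainable LoRA: {lora:,}  ({lora / full:.2%})")
# frozen: 9,437,184  trainable LoRA: 36,864  (0.39%)
```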

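from peft import LoraConfig
# LoRA config based on QLoRA paper & Sebastian Raschka experiment (r/alpha values assumed; tune for your setup)
peft_config = LoraConfig(
    lora_alpha=8, lora_dropout=0.05, r=6, bias="none",
    target_modules="all-linear", task_type="CAUSAL_LM",
)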

Before we can start our training we need to define the hyperparameters (TrainingArguments) we want to use.

from transformers import TrainingArguments
args = TrainingArguments(
    output_dir="gemma-7b-dolly-chatml", # directory to save and repository id
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=2,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of steps to accumulate gradients before an optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=False,                      # set to True to push the model to the Hub
    report_to="tensorboard",                # report metrics to tensorboard
)
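With per_device_train_batch_size=2 and gradient_accumulation_steps=2, each optimizer update sees 2 × 2 = 4 sequences per GPU. The helper below estimates the number of optimizer steps; with packing enabled the relevant count is the number of packed blocks, so the 4,000 used here is a hypothetical number, and the Trainer's own rounding may differ slightly.

```python
import math


def effective_batch_size(per_device_batch: int, grad_accum: int, num_devices: int = 1) -> int:
    """Sequences that contribute to one optimizer update."""
    return per_device_batch * grad_accum * num_devices


def approx_optimizer_steps(num_sequences: int, per_device_batch: int, grad_accum: int,
                           num_devices: int = 1, epochs: int = 1) -> int:
    """Rough estimate of total optimizer updates over all epochs."""
    per_step = effective_batch_size(per_device_batch, grad_accum, num_devices)
    return math.ceil(num_sequences / per_step) * epochs


# Batch settings from the TrainingArguments; 4,000 packed sequences is hypothetical.
print(effective_batch_size(2, 2))                    # 4
print(approx_optimizer_steps(4000, 2, 2, epochs=3))  # 3000
```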

We now have every building block we need to create our SFTTrainer and start training our model.

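from trl import SFTTrainer
max_seq_length = 1512 # max sequence length for model and packing of the dataset
trainer = SFTTrainer(
    model=model, args=args, train_dataset=dataset, peft_config=peft_config,
    max_seq_length=max_seq_length, tokenizer=tokenizer, packing=True,
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    },
)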

Start training our model by calling the train() method on our Trainer instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapted model weights and not the full model.

# start training, checkpoints will be saved to the output directory
trainer.train()
# save model (only the LoRA adapter weights) to the output directory
trainer.save_model()

The training with Flash Attention for 3 epochs on a dataset of 15k samples took 4:14:36 on a g5.2xlarge. The instance costs $1.21/h, which brings us to a total cost of only ~$5.1.
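The cost figure is simple arithmetic: wall-clock hours times the hourly price. A small sketch using the duration and price quoted for the g5.2xlarge run, assuming per-second billing with no rounding up to full hours:

```python
from datetime import timedelta


def training_cost(duration: timedelta, usd_per_hour: float) -> float:
    """Cost of a run billed per second of wall-clock time at an hourly rate."""
    return duration.total_seconds() / 3600 * usd_per_hour


cost = training_cost(timedelta(hours=4, minutes=14, seconds=36), usd_per_hour=1.21)
print(f"${cost:.2f}")  # $5.13
```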

Optional: Merge LoRA adapter into the original model

When using QLoRA, we only train adapters and not the full model. This means when saving the model during training we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference you can merge the adapter weights into the model weights using the merge_and_unload method and then save the model with the save_pretrained method.

Check out the How to Fine-Tune LLMs in 2024 with Hugging Face blog post for how to do it.
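Merging works because the adapter's contribution is linear: merge_and_unload folds the scaled low-rank update into the base weight, W' = W + (alpha / r) · B A (this scaling assumes peft's default LoRA setting, not rsLoRA). The pure-Python sketch below checks on a tiny, made-up example that the merged weight gives the same output as the base weight plus the adapter. With QLoRA, the merge is usually done after reloading the base model in 16-bit precision rather than into the 4-bit weights.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def matvec(m: list[list[float]], v: list[float]) -> list[float]:
    return [sum(row[k] * v[k] for k in range(len(v))) for row in m]


def merge_lora(W, A, B, alpha: float, r: int):
    """W' = W + (alpha / r) * B @ A."""
    scale = alpha / r
    BA = matmul(B, A)
    return [[W[i][j] + scale * BA[i][j] for j in range(len(W[0]))] for i in range(len(W))]


W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # frozen base weight, d_out=2 x d_in=3
A = [[0.5, -0.5, 1.0]]                   # r=1 x d_in=3
B = [[2.0], [-1.0]]                      # d_out=2 x r=1
alpha, r = 2.0, 1
x = [1.0, 2.0, 3.0]

# Adapter applied on the side: W x + (alpha / r) * B (A x)
unmerged = [w + (alpha / r) * b for w, b in zip(matvec(W, x), matvec(B, matvec(A, x)))]
# Adapter folded into the weight: W' x
merged = matvec(merge_lora(W, A, B, alpha, r), x)
print(unmerged, merged)  # [17.0, -6.0] [17.0, -6.0]
```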

4. Test Model and run Inference

After the training is done we want to evaluate and test our model. We will load different samples from the original dataset and evaluate the model on those samples, using a simple loop and accuracy as our metric.

Note: Evaluating Generative AI models is not a trivial task, since one input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out the Evaluate LLMs and RAG a practical example using Langchain and Hugging Face blog post.
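Since accuracy is mentioned as the metric, here is a minimal exact-match loop. `exact_match_accuracy` and the stub predictor are hypothetical; in practice `predict` could wrap the `test_inference` function defined later in this post. Exact match is a crude proxy for free-form answers, which is the limitation the note above describes.

```python
from typing import Callable, Sequence


def exact_match_accuracy(predict: Callable[[str], str], samples: Sequence[tuple[str, str]]) -> float:
    """Fraction of (prompt, reference) pairs where the normalized prediction matches exactly."""
    def normalize(text: str) -> str:
        return " ".join(text.lower().split())

    if not samples:
        return 0.0
    hits = sum(normalize(predict(prompt)) == normalize(reference) for prompt, reference in samples)
    return hits / len(samples)


# Hypothetical samples and a stub predictor standing in for the fine-tuned model.
samples = [("What is 2 + 2?", "4"), ("What is the capital of Germany?", "Berlin")]
canned = {"What is 2 + 2?": " 4 ", "What is the capital of Germany?": "Bonn"}
print(exact_match_accuracy(lambda prompt: canned[prompt], samples))  # 0.5
```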

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

We load the adapted model and the tokenizer into the pipeline to easily test it, and extract the token id of <|im_end|> to use as the stop token in the generate method.

import torch
from peft import AutoPeftModelForCausalLM
from transformers import  AutoTokenizer, pipeline
peft_model_id = "gemma-7b-dolly-chatml"
# Load Model with PEFT adapter
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# get token id for end of conversation
eos_token = tokenizer("<|im_end|>",add_special_tokens=False)["input_ids"][0]
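Passing eos_token_id tells generation to stop once that token is produced. The toy loop below illustrates that behavior with a scripted next-token function; it is a sketch of the idea, not the transformers generate implementation, and the token ids are made up.

```python
from typing import Callable


def generate_until(next_token: Callable[[list[int]], int], prompt_ids: list[int],
                   stop_id: int, max_new_tokens: int) -> list[int]:
    """Append tokens until stop_id is produced or max_new_tokens is reached."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        token = next_token(ids)
        ids.append(token)
        if token == stop_id:  # e.g. the id of <|im_end|>
            break
    return ids[len(prompt_ids):]  # only the newly generated tokens


scripted = iter([11, 12, 99, 13])
print(generate_until(lambda ids: next(scripted), prompt_ids=[1, 2], stop_id=99, max_new_tokens=10))
# [11, 12, 99]
```

Without the stop id, the loop would only end at max_new_tokens, which is why the fine-tuned model's <|im_end|> id is passed explicitly.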

Let's test some prompts and see how the model performs.

prompts = [
    "What is the capital of Germany? Explain why thats the case and if it was different in the past?",
    "Write a Python function to calculate the factorial of a number.",
    "A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?",
    "What is the difference between a fruit and a vegetable? Give examples of each.",
]
def test_inference(prompt):
    prompt = pipe.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=1024, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, eos_token_id=eos_token)
    return outputs[0]['generated_text'][len(prompt):].strip()
for prompt in prompts:
    print(f"    prompt:\n{prompt}")
    print(f"    response:\n{test_inference(prompt)}")

Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.