philschmid

Fine-tune Falcon 180B with DeepSpeed ZeRO, LoRA & Flash Attention


Falcon 180B is the newest model in the Falcon LLM family. It is the largest open model, with 180B parameters, and was trained on 3.5T tokens with a context window of up to 4K tokens. In this example, we will show how to fine-tune Falcon 180B using DeepSpeed, Hugging Face Transformers, and LoRA with Flash Attention on a multi-GPU machine.

In detail you will learn how to:

  1. Set up the development environment
  2. Load and prepare the dataset
  3. Fine-Tune Falcon 180B using DeepSpeed, Hugging Face Transformers, LoRA with Flash Attention

Before we get into the code, let's take a quick look at the technologies and methods we are going to use:

What is DeepSpeed ZeRO?

DeepSpeed ZeRO focuses on efficient large-scale training of Transformers. ZeRO, or Zero Redundancy Optimizer, reduces the memory footprint by partitioning model states (parameters, gradients, and optimizer states) across devices instead of replicating them on every device, as basic data parallelism does. This saves significant memory: ZeRO-Infinity can reduce memory usage 100x compared to data parallelism. ZeRO-Offload further reduces GPU memory by offloading parts of the model and optimizer to the CPU, enabling 10B+ parameter models on a single GPU. ZeRO integrates with Hugging Face Transformers through a configuration file.
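To make the partitioning concrete, here is a back-of-the-envelope sketch of per-GPU memory for model states under each ZeRO stage. It uses the accounting from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 bytes for 16-bit gradients, and 12 bytes for fp32 optimizer states. It ignores activations, buffers, and fragmentation. The function name is our own.

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Per-GPU memory (GB) for model states with mixed-precision Adam.

    Accounting from the ZeRO paper: 2 bytes/param for 16-bit weights,
    2 bytes/param for 16-bit gradients, 12 bytes/param for fp32 optimizer
    states (master weights, momentum, variance).
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:    # plain data parallelism: every GPU holds everything
        total = weights + grads + optim
    elif stage == 1:  # partition optimizer states
        total = weights + grads + optim / num_gpus
    elif stage == 2:  # partition optimizer states and gradients
        total = weights + (grads + optim) / num_gpus
    elif stage == 3:  # partition optimizer states, gradients and weights
        total = (weights + grads + optim) / num_gpus
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return total / 1e9


# full fine-tuning of a 180B-parameter model on 8 GPUs
for stage in range(4):
    print(f"stage {stage}: {zero_model_state_gb(180e9, 8, stage):,.0f} GB per GPU")
```

Even with everything partitioned, full fine-tuning would still need about 360 GB per GPU on 8 GPUs. That is far beyond the 80 GB A100s used in this post, which is one reason ZeRO is paired with LoRA here.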

What is LoRA?

LoRA enables efficient fine-tuning of large language models. Instead of updating a weight matrix directly, it learns the update as the product of two small, low-rank trainable matrices while the original weights stay frozen. This drastically reduces the number of trainable parameters, making tuning faster and less memory-hungry. LoRA integrates into Transformers via Hugging Face's PEFT library and combines well with methods like DeepSpeed. Key advantages are efficient tuning, portable models (the trained adapters are small), and no added inference latency once the trained weights are merged. LoRA makes it possible to adapt massive models with limited resources.
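Here is a minimal NumPy sketch of the LoRA idea for a single linear layer. It is not PEFT's implementation, and the layer size, rank `r`, `alpha`, and initialization are illustrative. The frozen weight `W` is augmented by a scaled low-rank product `B @ A`. Because `B` starts at zero, training begins from the pretrained behavior. After training, `B @ A` can be folded into `W`, which is why merged adapters add no inference latency.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 1024, 1024, 8, 16   # illustrative layer size, rank and scaling
scaling = alpha / r

W = rng.standard_normal((d_out, d_in))       # frozen pretrained weight
A = rng.standard_normal((r, d_in)) * 0.01    # trainable, small random init
B = np.zeros((d_out, r))                     # trainable, zero init: no change at step 0


def lora_forward(x, W, A, B):
    # frozen path + scaled low-rank update path; only A and B would be trained
    return x @ W.T + scaling * (x @ A.T) @ B.T


x = rng.standard_normal((4, d_in))
assert np.allclose(lora_forward(x, W, A, B), x @ W.T)  # B == 0 -> identical to base layer

B = rng.standard_normal((d_out, r)) * 0.01   # stand-in for B after some training
W_merged = W + scaling * (B @ A)             # fold the adapter into the base weight
assert np.allclose(lora_forward(x, W, A, B), x @ W_merged.T)

trainable, frozen = A.size + B.size, W.size
print(f"trainable: {trainable:,} vs frozen: {frozen:,} ({trainable / frozen:.2%})")
```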

What is Flash Attention?

Flash Attention is an algorithm that speeds up the core attention mechanism in Transformer language models by restructuring its computation. It uses techniques like tiling and recomputation to reduce the high memory cost of attention, enabling models to process longer text sequences. Flash Attention 2 improves parallelism and work partitioning for a roughly 2x speedup over the previous version, reaching up to 230 TFLOPS on A100 GPUs.
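The tiling idea can be illustrated in NumPy. Attention is computed one block of keys at a time, with a running row maximum and softmax denominator kept along the way. The full score matrix is never materialized, yet the result matches standard attention. This sketch covers only the forward-pass idea. The real kernel also tiles queries, works in on-chip SRAM, and recomputes attention during the backward pass. The function names and block size are illustrative.

```python
import numpy as np


def naive_attention(q, k, v):
    # materializes the full (n_q, n_k) score matrix
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block_size=16):
    # processes keys/values block by block with an online softmax
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.zeros((q.shape[0], v.shape[-1]))
    row_max = np.full((q.shape[0], 1), -np.inf)  # running max of scores per query
    row_sum = np.zeros((q.shape[0], 1))          # running softmax denominator
    for start in range(0, k.shape[0], block_size):
        k_blk, v_blk = k[start:start + block_size], v[start:start + block_size]
        s = q @ k_blk.T * scale                   # scores for this block only
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)    # rescale earlier partial sums
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v_blk
        row_max = new_max
    return out / row_sum


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((n, 64)) for n in (32, 128, 128))
print(np.allclose(tiled_attention(q, k, v), naive_attention(q, k, v)))
```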

Access Falcon 180B

Before we can start training, we have to make sure that we have accepted the license of tiiuae/falcon-180B to be able to use it. You can accept the license by clicking the "Agree and access repository" button on the model page on the Hugging Face Hub.

The example was created and run on a DGX A100 machine with 8 GPUs and 80GB of memory per GPU.
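A rough check of why this hardware is enough, assuming LoRA keeps the base weights frozen and ZeRO stage 3 shards them evenly across the 8 GPUs. Only the weights themselves are counted. Activations, the small LoRA parameters and their optimizer states, communication buffers, and CUDA overhead are ignored, so the real footprint is higher.

```python
NUM_PARAMS = 180e9     # Falcon 180B
NUM_GPUS = 8           # DGX A100
GPU_MEMORY_GB = 80     # per GPU


def sharded_weights_gb(num_params, bytes_per_param, num_gpus):
    """Per-GPU GB for frozen weights split evenly across GPUs (ZeRO stage 3)."""
    return num_params * bytes_per_param / num_gpus / 1e9


for dtype, nbytes in [("fp32", 4), ("bf16", 2)]:
    gb = sharded_weights_gb(NUM_PARAMS, nbytes, NUM_GPUS)
    print(f"{dtype}: {gb:.0f} GB of weights per GPU (fits in {GPU_MEMORY_GB} GB: {gb < GPU_MEMORY_GB})")
```

In fp32, the weights alone (90 GB per GPU) would not fit. In bf16 they take about 45 GB, which leaves headroom for activations on an 80 GB card.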

1. Set up the development environment

conda create --name hf python=3.10 -c conda-forge
conda activate hf

# install torch with the CUDA version matching your system (check with nvcc --version)
!pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --upgrade
# install the Hugging Face libraries and additional dependencies
!pip install "transformers==4.33.1" "datasets==2.14.5" "accelerate==0.22.0" "evaluate==0.4.0" "peft==0.5.0" tensorboard packaging --upgrade
# install DeepSpeed and ninja for JIT compilation of kernels
!pip install "deepspeed==0.10.3" ninja --upgrade
# install Flash Attention
!pip install flash-attn --no-build-isolation --upgrade
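The commands pin specific versions, so an optional check can confirm what actually ended up in the environment. The helper below is hypothetical and uses only the standard library's `importlib.metadata`. torch and flash-attn are left out because they were installed without pins.

```python
from importlib.metadata import PackageNotFoundError, version

# versions pinned in the pip commands above
PINNED = {
    "transformers": "4.33.1",
    "datasets": "2.14.5",
    "accelerate": "0.22.0",
    "evaluate": "0.4.0",
    "peft": "0.5.0",
    "deepspeed": "0.10.3",
}


def check_pins(pins):
    """Return {package: (expected, installed or None)} for every mismatch."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches


for package, (expected, installed) in check_pins(PINNED).items():
    print(f"{package}: expected {expected}, found {installed or 'nothing'}")
```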

To access any Falcon 180B asset, we need to log in to our Hugging Face account. We can do this by running the following command:

!huggingface-cli login --token YOUR_TOKEN

2. Load and prepare the dataset

We will use Dolly, an open-source dataset of instruction-following records generated by thousands of Databricks employees in several of the behavioral categories outlined in the InstructGPT paper, including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. A sample record looks like this:

{
  "instruction": "What is world of warcraft",
  "context": "",
  "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment"
}

To load the databricks-dolly-15k dataset, we use the load_dataset() method from the 🤗 Datasets library.

from datasets import load_dataset
from random import randrange

# Load dataset from the hub
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])
# dataset size: 15011

To instruction-tune our model, we need to convert our structured examples into a collection of tasks described via instructions. We define a formatting function, format_dolly, that takes a sample and returns a string in our instruction format.

def format_dolly(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    response = f"### Answer\n{sample['response']}"
    # join all the parts together
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    return prompt

Let's test our formatting function on a random example.

from random import randrange

print(format_dolly(dataset[randrange(len(dataset))]))
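For reference, this is what the template produces for the sample record shown earlier, without downloading anything. Because that record's context is empty, the `### Context` block is left out. The function is copied unchanged so the snippet runs on its own.

```python
def format_dolly(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    response = f"### Answer\n{sample['response']}"
    # join all the parts together, skipping the empty context
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    return prompt


sample = {
    "instruction": "What is world of warcraft",
    "context": "",
    "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment",
}
print(format_dolly(sample))
# ### Instruction
# What is world of warcraft
#
# ### Answer
# World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment
```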

In addition to formatting our samples, we also want to pack multiple samples into one sequence to make training more efficient. First, we load the tokenizer and use its EOS token for padding.

from transformers import AutoTokenizer

model_id = "tiiuae/falcon-180B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

We define some helper functions to apply the prompt template, tokenize the samples, and pack them into sequences of a given length.

from random import randint
from itertools import chain
from functools import partial


# template dataset to add prompt to each sample
def template_dataset(sample):
    sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
    return sample


# apply prompt template per sample
dataset = dataset.map(template_dataset, remove_columns=list(dataset.features))
# print random sample
print(dataset[randint(0, len(dataset) - 1)]["text"])

# empty list to save remainder from batches to use in next batch
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

def chunk(sample, chunk_length=2048):
    # define global remainder variable to save remainder from batches to use in next batch
    global remainder
    # Concatenate all texts and add remainder from previous batch
    concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
    concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
    # get total number of tokens for batch
    batch_total_length = len(concatenated_examples[list(sample.keys())[0]])

    # get max number of tokens that fit into full chunks (0 if the batch is too short)
    batch_chunk_length = 0
    if batch_total_length >= chunk_length:
        batch_chunk_length = (batch_total_length // chunk_length) * chunk_length

    # Split by chunks of max_len.
    result = {
        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
        for k, t in concatenated_examples.items()
    }
    # add remainder to global variable for next batch
    remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
    # prepare labels
    result["labels"] = result["input_ids"].copy()
    return result


# tokenize and chunk dataset
lm_dataset = dataset.map(
    lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)
).map(
    partial(chunk, chunk_length=2048),
    batched=True,
)

# Print total number of samples
print(f"Total number of samples: {len(lm_dataset)}")
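The packing step can be tried without a tokenizer. Below is a simplified, single-column version of the same logic on toy token ids. `pack` is a hypothetical helper, not part of the original code. Tokens are concatenated across batches and cut into fixed-length chunks. Leftovers are carried into the next batch, and tokens still left after the last batch are dropped, as in `chunk`.

```python
from itertools import chain


def pack(batches, chunk_length):
    """Yield fixed-length chunks from tokenized batches, carrying leftovers forward."""
    remainder = []
    for batch in batches:
        tokens = remainder + list(chain(*batch))                 # leftovers first
        usable = (len(tokens) // chunk_length) * chunk_length   # 0 if batch too short
        for i in range(0, usable, chunk_length):
            yield tokens[i : i + chunk_length]
        remainder = tokens[usable:]
    # anything still in `remainder` here is dropped


# two batches of toy token ids; the second batch starts with the first's leftover
print(list(pack([[[1, 2, 3], [4, 5]], [[6], [7, 8, 9, 10]]], chunk_length=4)))
# [[1, 2, 3, 4], [5, 6, 7, 8]]

# a batch shorter than one chunk emits nothing and is carried over entirely
print(list(pack([[[1, 2]], [[3, 4]]], chunk_length=4)))
# [[1, 2, 3, 4]]
```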

After processing the dataset, we save it to disk so that we can load the processed dataset later during training.

lm_dataset.save_to_disk("dolly-processed")

3. Fine-Tune Falcon 180B using DeepSpeed, Hugging Face Transformers, LoRA with Flash Attention

DeepSpeed ZeRO is natively integrated into the Hugging Face Transformers Trainer. The integration lets us use ZeRO by simply providing a DeepSpeed config file; the Trainer takes care of the rest. We created two DeepSpeed configurations for the experiments we ran, including one with CPU offloading.

As mentioned at the beginning, we ran this example on 8x NVIDIA A100 80GB GPUs. This means we can use bf16, which roughly halves the model's memory footprint compared to fp32 and allows us to train efficiently without offloading. We are going to use ds_falcon_180b_z3.json. If you are confused by the "auto" values, check the documentation of the Transformers DeepSpeed integration.
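The contents of ds_falcon_180b_z3.json are not reproduced here. As a hypothetical illustration only, a minimal ZeRO stage 3 configuration using the Trainer's "auto" placeholders could be generated like this. The keys are standard DeepSpeed options, but the selection and the output file name `ds_z3_sketch.json` are our own assumptions. CPU offloading is left out, matching the setup described above.

```python
import json

# Hypothetical minimal ZeRO-3 config; "auto" values are filled in by the
# Hugging Face Trainer from its TrainingArguments at launch time.
ds_config = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        # gather the sharded 16-bit weights into one state dict when saving
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("ds_z3_sketch.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```

The Trainer picks up such a file when its path is passed via `--deepspeed`. An offloading variant would additionally set `offload_optimizer` and `offload_param` with `"device": "cpu"` inside `zero_optimization`.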

In addition to the DeepSpeed configuration, we also need a training script that implements LoRA and patches our model to use Flash Attention. We created a run_ds_lora.py script, which patches the Falcon model using the falcon_patch.py utils and implements LoRA using peft_utils.py.

Before you run it, make sure that you have the same folder structure and the utils and configs files available. The easiest way is to clone the whole repository, go into the training directory, and start the training from there.
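An optional pre-flight check can catch a missing file before launch. The helper below is hypothetical. `run_ds_lora.py` and `configs/ds_falcon_180b_z3.json` come from the launch command. The `utils/` locations of `falcon_patch.py` and `peft_utils.py` are our guess at the repository layout, so adjust them to match your clone.

```python
from pathlib import Path

# Files the launch depends on. The utils/ paths are an assumption about the repo layout.
REQUIRED_FILES = [
    "run_ds_lora.py",
    "configs/ds_falcon_180b_z3.json",
    "utils/falcon_patch.py",
    "utils/peft_utils.py",
]


def missing_files(root=".", required=REQUIRED_FILES):
    """Return the required paths that are not regular files under `root`."""
    return [p for p in required if not (Path(root) / p).is_file()]


missing = missing_files()
print("all files present" if not missing else f"missing: {', '.join(missing)}")
```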

Once we have made sure that we have the right configuration and training script, we can start the training using torchrun.

!torchrun --nproc_per_node 8 run_ds_lora.py \
  --model_id tiiuae/falcon-180B \
  --dataset_path dolly-processed \
  --output_dir falcon-180b-lora-fa \
  --num_train_epochs 3 \
  --per_device_train_batch_size 1 \
  --learning_rate 4e-3 \
  --gradient_checkpointing True \
  --gradient_accumulation_steps 8 \
  --bf16 True \
  --tf32 True \
  --use_flash_attn True \
  --lr_scheduler_type "constant_with_warmup" \
  --logging_steps 25 \
  --save_steps 100 \
  --save_total_limit 3 \
  --deepspeed configs/ds_falcon_180b_z3.json
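The launch flags determine how much data goes into each optimizer step. Here is a quick calculation, assuming the 2048-token packed sequences created earlier.

```python
num_gpus = 8                     # --nproc_per_node
per_device_train_batch_size = 1  # --per_device_train_batch_size
gradient_accumulation_steps = 8  # --gradient_accumulation_steps
sequence_length = 2048           # chunk_length used when packing the dataset

sequences_per_step = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
tokens_per_step = sequences_per_step * sequence_length
print(f"{sequences_per_step} sequences = {tokens_per_step:,} tokens per optimizer step")
# 64 sequences = 131,072 tokens per optimizer step
```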

Note: Since we are using LoRA, we only save the trained adapter weights to save storage. If you want to merge the adapters back into the base model and save the merged model, you can add --merge_adapters True or use the merge_adapter_weights.py script.

In our example for Falcon 180B, the training took 153 minutes (about 2.5 hours) for 3 epochs on 8 GPUs, or roughly 20 GPU hours. For comparison, pretraining Falcon 180B took ~7,000,000 GPU hours, roughly 340,000 times more than this fine-tuning run.
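The pretraining figure is given in GPU hours, so the fine-tuning run (153 minutes on 8 GPUs) has to be converted to GPU hours before the two can be compared.

```python
pretraining_gpu_hours = 7_000_000     # ~7M GPU hours for Falcon 180B pretraining
finetuning_gpu_hours = 8 * 153 / 60   # 8 GPUs x 153 minutes = 20.4 GPU hours
ratio = pretraining_gpu_hours / finetuning_gpu_hours
print(f"fine-tuning: {finetuning_gpu_hours:.1f} GPU hours, pretraining is ~{ratio:,.0f}x more")
```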

Conclusion

In this blog post, you learned how to fine-tune Falcon 180B using DeepSpeed, Hugging Face Transformers, and LoRA with Flash Attention on a multi-GPU machine. We used:

  • DeepSpeed ZeRO for memory optimization, which enables training models with up to trillions of parameters on limited GPU memory. We used ZeRO stage 3, which partitions optimizer states, gradients, and parameters across GPUs.
  • Hugging Face Transformers and Datasets for easily loading and preparing the text dataset as well as providing an intuitive Trainer API.
  • LoRA, a method to efficiently fine-tune large language models by training only a small number of added low-rank parameters while the original weights stay frozen. This drastically reduces memory usage and computational costs.
  • Flash Attention, a highly optimized attention implementation that further reduces the memory footprint.

Combining all of these methods allows us to fine-tune LLMs with more than 100B parameters with limited resources. The example provides a template for efficiently tuning the largest publicly available models.


Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.