Extended Guide: Instruction-tune Llama 2
This blog post is an extended guide on instruction-tuning Llama 2 from Meta AI. It focuses on creating the instruction dataset, which we can then use to fine-tune the base Llama 2 model to follow our instructions.
The goal is to create a model that generates instructions from a given input. Others can then use such a model to create instruction data from their own inputs. This is especially helpful if you want to personalize models for, e.g., tweeting or email writing: you could generate an instruction dataset from your emails and then train a model to mimic your email writing.
Okay, let's get started. In this blog, we are going to:
- Define the use case and create a prompt template for instructions
- Create an instruction dataset
- Instruction-tune Llama 2 using trl and the SFTTrainer
- Test the Model and run Inference
Note: This tutorial was created and run on a g5.2xlarge AWS EC2 instance with an NVIDIA A10G GPU.
1. Define the use case and create a prompt template for instructions
Before we describe our use case, we need to better understand what an instruction actually is.
An instruction is a piece of text or prompt that is provided to an LLM, like Llama, GPT-4, or Claude, to guide it to generate a response. Instructions allow humans to steer the conversation and constrain the language model's output to be more natural, useful, and aligned with the user's goals. Crafting clear, well-formulated instructions is key to productive conversations.
Examples of instructions are listed below in the table.
| Capability | Example Instruction |
|---|---|
| Brainstorming | Provide a diverse set of creative ideas for new flavors of ice cream. |
| Classification | Categorize these movies as either comedy, drama, or horror based on the plot summary. |
| Closed QA | Answer the question 'What is the capital of France?' with a single word. |
| Generation | Write a poem in the style of Robert Frost about nature and the changing seasons. |
| Information Extraction | Extract the names of the main characters from this short story. |
| Open QA | Why do leaves change color in autumn? Explain the scientific reasons. |
| Summarization | Summarize this article on recent advancements in renewable energy in 2-3 sentences. |
As described at the beginning, we want to fine-tune a model that generates instructions from an input, where the input is the text (the output) that the instruction should produce. We want to use this as a way to create synthetic datasets to personalize LLMs and agents.
Converting this idea into a basic prompt template that follows the Alpaca format, we get:
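As a minimal sketch, assuming the standard Alpaca section headers (### Instruction, ### Input, ### Response) and a task line worded by us (hypothetical, not necessarily the original template), the template could look like this. The text we want to explain goes into the Input, and the instruction the model should produce goes into the Response:

```python
# Alpaca-style template for the reverse task (text -> instruction).
# The wording of the task line is a hypothetical example.
PROMPT_TEMPLATE = """### Instruction:
Write the instruction that could have been used to generate the Input below with an LLM.

### Input:
{input}

### Response:
{response}"""


def build_prompt(input_text: str, instruction: str = "") -> str:
    """Fill the template; leave `instruction` empty at inference time so the
    model completes the Response section itself."""
    return PROMPT_TEMPLATE.format(input=input_text, response=instruction)


if __name__ == "__main__":
    email = "Dear team, I will be out of office next week and back on Monday."
    print(build_prompt(email))
```

During training both sections are filled; at inference time the prompt ends right after "### Response:" and the model writes the instruction.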
2. Create an instruction dataset
After defining our use case and prompt template, we need to create our instruction dataset. A high-quality instruction dataset is key to a well-performing model. The paper "LIMA: Less Is More for Alignment" shows that a small, high-quality dataset (~1,000 samples) can achieve performance comparable to larger datasets of lower quality.
There are several ways to create an instruction dataset, including:
- Converting an existing dataset into an instruction dataset, e.g., FLAN
- Using existing LLMs to create synthetic instruction datasets, e.g., Alpaca
- Using humans to create instruction datasets, e.g., Dolly
Each method has its own advantages and disadvantages, and the right choice depends on budget, time, and quality requirements. For example, using an existing dataset is the easiest but might not be tailored to your specific use case, while using humans might be the most accurate but can be time-consuming and expensive. It is also possible to combine several methods to create an instruction dataset, as shown in Orca: Progressive Learning from Complex Explanation Traces of GPT-4.
To keep it simple, we are going to use Dolly, an open-source dataset of instruction-following records generated by thousands of Databricks employees in several of the behavioral categories outlined in the InstructGPT paper, including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization.
Let's start coding, but first, let's install our dependencies.
To load the databricks/databricks-dolly-15k dataset, we use the load_dataset() method from the 🤗 Datasets library.
To instruction-tune our model, we need to convert our structured examples into a collection of tasks described via instructions. We define a formatting_function that takes a sample and returns a string in our instruction format.
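A minimal sketch of such a function, assuming Dolly's `instruction` and `response` fields and an Alpaca-style layout whose task wording is our own (hypothetical). For the reverse task, the record's `response` becomes the Input and its `instruction` becomes the target Response:

```python
import random


def formatting_function(sample: dict) -> str:
    """Format one Dolly-style record for the reverse task.

    The record's `response` becomes the Input and its `instruction` becomes the
    target Response. Dolly's `context` and `category` fields are ignored here.
    """
    return (
        "### Instruction:\n"
        "Write the instruction that could have been used to generate the Input below with an LLM.\n\n"
        f"### Input:\n{sample['response']}\n\n"
        f"### Response:\n{sample['instruction']}"
    )


if __name__ == "__main__":
    # Hypothetical records with the same fields as databricks-dolly-15k rows.
    samples = [
        {"instruction": "Name three primary colors.", "context": "",
         "response": "Red, yellow, and blue.", "category": "open_qa"},
        {"instruction": "Give me a one-line definition of photosynthesis.", "context": "",
         "response": "Plants turn light, water, and CO2 into sugar and oxygen.", "category": "open_qa"},
    ]
    # Pick a random record and print the formatted training string.
    print(formatting_function(random.choice(samples)))
```

trl's SFTTrainer accepts a function like this through its formatting_func argument, so the dataset itself can stay in its original structured form.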
Let's test our formatting function on a random example.
3. Instruction-tune Llama 2 using trl and the SFTTrainer
We will use the recently introduced method from the paper "QLoRA: Efficient Finetuning of Quantized LLMs" by Tim Dettmers et al. QLoRA is a technique that reduces the memory footprint of large language models during finetuning without sacrificing performance. The TL;DR of how QLoRA works:
- Quantize the pre-trained model to 4 bits and freeze it.
- Attach small, trainable adapter layers (LoRA).
- Finetune only the adapter layers while using the frozen quantized model for context.
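To make the three steps concrete, here is a toy, pure-Python sketch of a single linear layer. It swaps QLoRA's blockwise NF4 quantization for a simpler absmax 4-bit scheme and omits the training loop entirely; the class and function names are hypothetical, and this is not how bitsandbytes or peft implement it:

```python
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def quantize_absmax_4bit(m):
    """Toy symmetric 4-bit quantization: every weight becomes an integer in [-7, 7].

    QLoRA actually uses blockwise NF4; this absmax scheme only illustrates
    storing the frozen weights in 4 bits."""
    absmax = max(abs(w) for row in m for w in row) or 1.0
    scale = absmax / 7
    q = [[max(-7, min(7, round(w / scale))) for w in row] for row in m]
    return q, scale


class ToyQLoRALinear:
    def __init__(self, weight, r=2, alpha=4, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        # 1. Quantize the pre-trained weight and freeze it (it is never updated).
        self.q_weight, self.scale = quantize_absmax_4bit(weight)
        # 2. Attach a small trainable adapter: A (r x d_in) random, B (d_out x r) zero,
        #    so the adapter contributes nothing before training.
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        self.B = [[0.0] * r for _ in range(d_out)]
        self.scaling = alpha / r

    def trainable_parameters(self):
        # 3. Only A and B would receive gradient updates during finetuning.
        return sum(map(len, self.A)) + sum(map(len, self.B))

    def forward(self, x):
        # The frozen weight is dequantized on the fly for the matrix-vector product.
        frozen = [[v * self.scale for v in row] for row in self.q_weight]
        base = matvec(frozen, x)
        delta = matvec(self.B, matvec(self.A, x))
        return [b + self.scaling * d for b, d in zip(base, delta)]


if __name__ == "__main__":
    W = [[0.5, -0.25, 0.1, 0.0], [0.3, 0.7, -0.6, 0.2]]
    layer = ToyQLoRALinear(W, r=2, alpha=4)
    print(layer.q_weight, layer.trainable_parameters())
```

With such tiny shapes the adapter is not smaller than the frozen weight; the savings only appear at LLM scale, where the matrix dimensions are in the thousands and r stays small.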
If you want to learn more about QLoRA and how it works, I recommend reading the Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA blog post.
Flash Attention
Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. It is based on the paper "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness". In short, it can accelerate training by up to 3x. Learn more at FlashAttention. Flash Attention is currently only available for Ampere (A10, A40, A100, ...) and Hopper (H100, ...) GPUs. You can check if your GPU is supported and install it using the following command:
Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of MAX_JOBS. On the g5.2xlarge we used 4.
Installing flash attention can take quite a bit of time (10-45 minutes).
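Given how long the build takes, it can be worth checking the GPU's CUDA compute capability first. A minimal sketch, assuming PyTorch is installed: Ampere GPUs such as the A10G report 8.x and Hopper GPUs report 9.0, so a major version of at least 8 is used as a heuristic cutoff. The helper name is hypothetical:

```python
def supports_flash_attention(capability: tuple[int, int]) -> bool:
    """Heuristic: Flash Attention needs an Ampere (8.x) or newer GPU, e.g. Hopper (9.0)."""
    major, _minor = capability
    return major >= 8


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    # Only query the device if PyTorch is installed and a CUDA GPU is visible.
    if torch is not None and torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        verdict = "supported" if supports_flash_attention(capability) else "not supported"
        print(f"compute capability {capability[0]}.{capability[1]}: {verdict}")
```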
Flash Attention is supported for all Llama checkpoints in this example, but it is not enabled by default. To use Flash Attention, change the value of use_flash_attention to True.
The SFTTrainer supports a native integration with peft, which makes it super easy to efficiently instruction-tune LLMs. We only need to create our LoraConfig and provide it to the trainer.
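To see why this is cheap, we can count the parameters a LoRA adapter adds. A small sketch, assuming Llama 2 7B's shapes (hidden size 4096, 32 decoder layers), a rank of r=64, and adapters on the q_proj and v_proj attention projections (commonly used default targets for Llama); these values are illustrative and not necessarily the exact configuration of this run:

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters of one LoRA adapter: A is r x d_in and B is d_out x r."""
    return r * d_in + d_out * r


def llama_lora_params(hidden_size=4096, num_layers=32, r=64,
                      target_modules=("q_proj", "v_proj")) -> int:
    # Assumes each targeted projection is a square hidden_size x hidden_size matrix,
    # which holds for q_proj and v_proj in Llama 2 7B (no grouped-query attention).
    per_layer = len(target_modules) * lora_params(hidden_size, hidden_size, r)
    return per_layer * num_layers


if __name__ == "__main__":
    added = llama_lora_params()
    print(f"{added:,} trainable parameters")                  # 33,554,432
    print(f"~{added / 6.7e9:.2%} of a ~6.7B-parameter model")  # ~0.50%
```

Only these roughly 34M adapter weights are trained, while the ~6.7B base weights stay frozen in 4 bits.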
Before we can start our training, we need to define the hyperparameters (TrainingArguments) we want to use.
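Two quantities that follow from these hyperparameters are the effective batch size and the number of optimizer steps. The sketch below uses hypothetical values (batch size 4, gradient accumulation 2, 3 epochs, about 15k samples) and ignores packing; the argument names mirror TrainingArguments fields, but the exact step count reported by the Trainer can differ slightly, so treat the result as an estimate:

```python
import math


def training_steps(num_samples: int, per_device_train_batch_size: int,
                   gradient_accumulation_steps: int, num_train_epochs: int,
                   num_devices: int = 1) -> tuple[int, int]:
    """Return (effective batch size, approximate total optimizer steps)."""
    effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices
    steps_per_epoch = math.ceil(num_samples / effective_batch)
    return effective_batch, steps_per_epoch * num_train_epochs


if __name__ == "__main__":
    # Hypothetical values for a single-GPU run such as the g5.2xlarge.
    print(training_steps(15_000, 4, 2, 3))  # (8, 5625)
```

With packing enabled, num_samples would be the number of packed sequences rather than the number of raw records.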
We now have every building block we need to create our SFTTrainer and start training our model.
Start training our model by calling the train() method on our Trainer instance.
The training without Flash Attention enabled took 03:08:00 on a g5.2xlarge. The instance costs $1.212/h, which brings us to a total cost of about $3.8.
The training with Flash Attention enabled took 02:08:00 on a g5.2xlarge. The instance costs $1.212/h, which brings us to a total cost of about $2.6.
The results with Flash Attention are impressive: training is about 1.5x faster and roughly 30% cheaper.
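The speedup and savings follow directly from the two wall-clock times and the $1.212/h price. A quick check (the helper name is ours):

```python
PRICE_PER_HOUR = 1.212  # g5.2xlarge on-demand price quoted above, in USD


def to_hours(hms: str) -> float:
    """Convert an 'HH:MM:SS' duration into hours."""
    h, m, s = (int(part) for part in hms.split(":"))
    return h + m / 60 + s / 3600


baseline = to_hours("03:08:00")  # without Flash Attention
flash = to_hours("02:08:00")     # with Flash Attention

if __name__ == "__main__":
    print(f"cost without Flash Attention: ${baseline * PRICE_PER_HOUR:.2f}")  # $3.80
    print(f"cost with Flash Attention:    ${flash * PRICE_PER_HOUR:.2f}")     # $2.59
    print(f"speedup: {baseline / flash:.2f}x")                                 # 1.47x
    print(f"savings: {1 - flash / baseline:.0%}")                               # 32%
```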
4. Test Model and run Inference
After the training is done, we want to run and test our model. We will use peft and transformers to load our LoRA adapter into our model.
Let's load the dataset again and pick a random sample to try generating an instruction.
Prompt:
Jack Dorsey, Noah Glass, Biz Stone, Evan Williams
Generated instruction:
Extract the founders of Twitter from the passage. Display the results in a comma separated format.
Ground truth:
List the founders of Twitter from the above passage in a comma separated format.
Nice! Our model works! If we want to accelerate our model, we can deploy it with Text Generation Inference. To do so, we first need to merge our adapter weights into the base model.
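Merging works because a LoRA adapter is just a low-rank update to the frozen weight, W' = W + (alpha / r) · B A, so it can be folded in once and the adapter dropped (peft exposes this as merge_and_unload()). The pure-Python sketch below, with hypothetical helper names and toy numbers, shows that the merged weight gives the same output as the base weight plus the adapter:

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(x * y for x, y in zip(row, v)) for row in m]


def merge_lora(weight, A, B, alpha, r):
    """Fold the adapter into the base weight: W' = W + (alpha / r) * (B @ A)."""
    scaling = alpha / r
    delta = matmul(B, A)
    return [[w + scaling * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]


def adapter_forward(weight, A, B, alpha, r, x):
    """Unmerged forward pass: W x + (alpha / r) * B (A x)."""
    scaling = alpha / r
    base = matvec(weight, x)
    delta = matvec(B, matvec(A, x))
    return [b + scaling * d for b, d in zip(base, delta)]


if __name__ == "__main__":
    W = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]  # frozen 2 x 3 weight
    A = [[0.1, 0.2, 0.3]]                     # r = 1, shape 1 x 3
    B = [[0.5], [-0.4]]                       # shape 2 x 1
    x = [1.0, 2.0, 3.0]
    merged = matvec(merge_lora(W, A, B, alpha=2, r=1), x)
    unmerged = adapter_forward(W, A, B, alpha=2, r=1, x=x)
    # Both paths agree up to floating-point rounding.
    print(all(math.isclose(m, u) for m, u in zip(merged, unmerged)))
```

After merging, inference needs only one matrix per layer, which is why a merged model can be served like any regular checkpoint.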
Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.