philschmid blog

Task-specific knowledge distillation for BERT using Transformers & Amazon SageMaker

#HuggingFace #AWS #BERT #PyTorch
February 01, 2022 · 11 min read

Photo by Paul Byrne on Unsplash

Welcome to this end-to-end task-specific knowledge distillation Text-Classification example using Transformers, PyTorch & Amazon SageMaker. Distillation is the process of training a small “student” to mimic a larger “teacher”. In this example, we will use a BERT-base as Teacher and BERT-Tiny as Student. We will use Text-Classification as our task-specific knowledge distillation task and the Stanford Sentiment Treebank v2 (SST-2) dataset for training.

There are two different types of knowledge distillation: task-agnostic knowledge distillation (right) and task-specific knowledge distillation (left). In this example we are going to use task-specific knowledge distillation.

Figure: Task-specific distillation (left) versus task-agnostic distillation (right). Figure from FastFormers by Y. Kim and H. Awadalla [arXiv:2010.13382].

In task-specific knowledge distillation, a “second step of distillation” is used to “fine-tune” the model on a given dataset. This idea comes from the DistilBERT paper, where it was shown that a student trained with this additional distillation step performed better than one obtained by simply fine-tuning the distilled language model:

We also studied whether we could add another step of distillation during the adaptation phase by fine-tuning DistilBERT on SQuAD using a BERT model previously fine-tuned on SQuAD as a teacher for an additional term in the loss (knowledge distillation). In this setting, there are thus two successive steps of distillation, one during the pre-training phase and one during the adaptation phase. In this case, we were able to reach interesting performances given the size of the model: 79.8 F1 and 70.4 EM, i.e. within 3 points of the full model.

If you want to dig deeper into these topics, you should definitely read the DistilBERT and FastFormers papers mentioned above. The FastFormers paper in particular contains great research on what works and doesn’t work when using knowledge distillation.

Huge thanks to Lewis Tunstall and his great Weeknotes: Distilling distilled transformers


#%pip install "torch==1.10.1"
%pip install transformers datasets tensorboard --upgrade

!sudo apt-get install git-lfs

This example will use the Hugging Face Hub as a remote model versioning service. To be able to push our model to the Hub, you need to register on the Hugging Face Hub. If you already have an account, you can skip this step. After you have an account, we will use the notebook_login util from the huggingface_hub package to log into our account and store our token (access key) on disk.

from huggingface_hub import notebook_login

notebook_login()

Setup & Configuration

In this step we will define global configurations and parameters, which are used across the whole end-to-end fine-tuning process, e.g. the teacher and student we will use.

In this example, we will use BERT-base as Teacher and BERT-Tiny as Student. Our Teacher is already fine-tuned on our dataset, which makes it easy for us to directly start the distillation training job rather than fine-tuning the teacher first to then distill it afterwards.

IMPORTANT: This example will only work with a Teacher & Student combination where both tokenizers create the same output.

Additionally, the FastFormers: Highly Efficient Transformer Models for Natural Language Understanding paper describes another phenomenon:

In our experiments, we have observed that distilled models do not work well when distilled to a different model type. Therefore, we restricted our setup to avoid distilling RoBERTa model to BERT or vice versa. The major difference between the two model groups is the input token (sub-word) embedding. We think that different input embedding spaces result in different output embedding spaces, and knowledge transfer with different spaces does not work well.

student_id = "google/bert_uncased_L-2_H-128_A-2"
teacher_id = "textattack/bert-base-uncased-SST-2"

# name for our repository on the hub
repo_name = "tiny-bert-sst2-distilled"

Below is a check to make sure the Teacher & Student tokenizers create the same output.

from transformers import AutoTokenizer

# init tokenizer
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_id)
student_tokenizer = AutoTokenizer.from_pretrained(student_id)

# sample input
sample = "This is a basic example, with different words to test."

# assert results
assert teacher_tokenizer(sample) == student_tokenizer(sample), "Tokenizers haven't created the same output"
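
A single sentence can hide differences that only show up with other inputs, so it can be worth running the same comparison over several samples. The following is a minimal sketch of that idea, not part of the original notebook: the helper name find_tokenizer_mismatches and the two toy tokenizers are hypothetical stand-ins, used so that the snippet runs without downloading any model. In practice you would pass teacher_tokenizer and student_tokenizer instead.

```python
from typing import Callable, Iterable, List


def find_tokenizer_mismatches(
    tokenize_a: Callable[[str], object],
    tokenize_b: Callable[[str], object],
    samples: Iterable[str],
) -> List[str]:
    """Return every sample for which the two tokenizers disagree."""
    return [text for text in samples if tokenize_a(text) != tokenize_b(text)]


# Toy stand-ins (assumption): real code would pass the teacher and
# student tokenizers loaded with AutoTokenizer instead.
def lowercase_whitespace_tokenizer(text: str) -> List[str]:
    return text.lower().split()


def cased_whitespace_tokenizer(text: str) -> List[str]:
    return text.split()


samples = ["a plain sentence", "Mixed Case Words", "numbers 123 and symbols"]
mismatches = find_tokenizer_mismatches(
    lowercase_whitespace_tokenizer, cased_whitespace_tokenizer, samples
)
print(mismatches)  # only the mixed-case sample differs
```

An empty result means the two tokenizers agreed on every sample you tried, which is a stronger signal than a single assert, although it is still not a proof of identical vocabularies.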

Dataset & Pre-processing

As dataset we will use the Stanford Sentiment Treebank v2 (SST-2), a text-classification dataset for sentiment analysis, which is included in the GLUE benchmark. The dataset is based on the dataset introduced by Pang and Lee (2005) and consists of 11,855 single sentences extracted from movie reviews. It was parsed with the Stanford parser and includes a total of 215,154 unique phrases from those parse trees, each annotated by 3 human judges. It uses the two-way (positive/negative) class split, with only sentence-level labels.

dataset_id = "glue"
dataset_config = "sst2"

To load the sst2 dataset, we use the load_dataset() method from the 🤗 Datasets library.

from datasets import load_dataset

dataset = load_dataset(dataset_id, dataset_config)

Pre-processing & Tokenization

To distill our model we need to convert our “Natural Language” to token IDs. This is done by a 🤗 Transformers Tokenizer which will tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary). If you are not sure what this means check out chapter 6 of the Hugging Face Course.

We are going to use the tokenizer of the Teacher, but since both create the same output you could also go with the Student tokenizer.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(teacher_id)

Additionally, we set truncation=True and max_length=512 to truncate texts that are longer than the maximum size allowed by the model.

def process(examples):
    # tokenize a batch of sentences and cut them at the model's maximum length
    tokenized_inputs = tokenizer(
        examples["sentence"], truncation=True, max_length=512
    )
    return tokenized_inputs

tokenized_datasets = dataset.map(process, batched=True)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

tokenized_datasets["test"].features
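
To make the effect of truncation=True and max_length concrete, here is a rough, framework-free sketch of what happens to a single sequence. It is a simplification and not the tokenizer's actual implementation: the function name truncate_and_wrap is hypothetical, and the input ids are arbitrary. The only values taken from BERT are the ids 101 and 102, which are [CLS] and [SEP] in the bert-base-uncased vocabulary.

```python
from typing import List

CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] ids in the bert-base-uncased vocabulary


def truncate_and_wrap(token_ids: List[int], max_length: int) -> List[int]:
    """Cut a sequence so that, with [CLS] and [SEP] added, it fits max_length."""
    if max_length < 2:
        raise ValueError("max_length must leave room for [CLS] and [SEP]")
    kept = token_ids[: max_length - 2]  # reserve two slots for the special tokens
    return [CLS_ID] + kept + [SEP_ID]


# A short input is left untouched, a long one is cut from the right.
short = truncate_and_wrap([2023, 2003, 2204], max_length=8)
long = truncate_and_wrap(list(range(1000, 1020)), max_length=8)
print(len(short), len(long))
```

Padding is deliberately not shown here, because in this example it is applied later, per batch, by the data collator.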

Distilling the model using PyTorch and DistillationTrainer

Now that our dataset is processed, we can distill our model. Normally, when fine-tuning a transformer model using PyTorch you should go with the Trainer API. The Trainer class provides an API for feature-complete training in PyTorch for most standard use cases.

In our example we cannot use the Trainer out of the box, since we need to pass in two models, the Teacher and the Student, and compute a loss that combines both outputs. But we can subclass the Trainer to create a DistillationTrainer that takes care of this by overriding only the compute_loss method and the init method. In addition, we also need to subclass TrainingArguments to include our distillation hyperparameters.

from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)

        # alpha weights the student loss against the distillation loss
        self.alpha = alpha
        # temperature softens the teacher and student distributions
        self.temperature = temperature


class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher, self.model.device)
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False):

        # compute student output
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss
        # compute teacher output
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        # assert size
        assert outputs_student.logits.size() == outputs_teacher.logits.size()

        # Soften probabilities and compute distillation loss
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (loss_function(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        # Return weighted student loss
        loss = self.args.alpha * student_loss + (1. - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss
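
To see what compute_loss calculates, it can help to reproduce the arithmetic for one example without PyTorch. The sketch below is an illustration under simplifying assumptions (a single example instead of a batch, plain Python lists instead of tensors); the function names are my own and not part of Transformers. It combines the student's cross-entropy on the hard label with the KL divergence between the temperature-softened teacher and student distributions, scaled by the squared temperature, and weights the two terms with alpha.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; higher temperatures give softer distributions."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=2.0):
    """alpha * cross-entropy + (1 - alpha) * T^2 * KL(teacher || student)."""
    # Hard-label term: cross-entropy of the student at temperature 1
    student_probs = softmax(student_logits)
    cross_entropy = -math.log(student_probs[label])

    # Soft-label term: KL divergence between the softened distributions
    soft_student = softmax(student_logits, temperature)
    soft_teacher = softmax(teacher_logits, temperature)
    kl = sum(t * math.log(t / s) for t, s in zip(soft_teacher, soft_student) if t > 0)

    return alpha * cross_entropy + (1.0 - alpha) * kl * temperature ** 2


teacher = [4.0, -2.0]
matching = distillation_loss([4.0, -2.0], teacher, label=0)
disagreeing = distillation_loss([-2.0, 4.0], teacher, label=0)
print(matching < disagreeing)  # a student that agrees with the teacher has the lower loss
```

With alpha=1.0 the loss reduces to ordinary fine-tuning on the labels, and with alpha=0.0 the student learns only from the teacher's soft predictions.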

Hyperparameter Definition, Model Loading

from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding
from huggingface_hub import HfFolder

# create label2id, id2label dicts for nice outputs for the model
labels = tokenized_datasets["train"].features["labels"].names
num_labels = len(labels)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

# define training args
training_args = DistillationTrainingArguments(
    output_dir=repo_name,
    num_train_epochs=7,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    fp16=True,
    learning_rate=6e-5,
    seed=33,
    # logging & evaluation strategies
    logging_dir=f"{repo_name}/logs",
    logging_strategy="epoch",  # to get more information to TB
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    # push to hub parameters
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    hub_token=HfFolder.get_token(),
    # distillation parameters
    alpha=0.5,
    temperature=4.0
)

# define data_collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# define teacher model
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# define student model
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

Evaluation metric

We can create a compute_metrics function to evaluate our model on the validation set. This function will be used during the training process to compute the accuracy of our model.

from datasets import load_metric
import numpy as np

# define metrics and metrics function
accuracy_metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # pick the class with the highest logit for each example
    predictions = np.argmax(predictions, axis=1)
    acc = accuracy_metric.compute(predictions=predictions, references=labels)
    return {
        "accuracy": acc["accuracy"],
    }
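
If you want to check by hand what compute_metrics returns, the following dependency-free sketch performs the same two steps: take the argmax over the logits of each example, then count how many predictions match the labels. The logits and labels are made up for illustration, and the helper names are my own.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value in a row (first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy_from_logits(logits: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the reference label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Made-up logits for four examples with two classes (negative, positive)
logits = [[2.1, -0.3], [0.2, 1.5], [-1.0, 0.4], [0.9, 0.8]]
labels = [0, 1, 0, 0]
print(accuracy_from_logits(logits, labels))  # 3 of 4 correct -> 0.75
```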


Start training by calling trainer.train()

trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Start training using the DistillationTrainer.

trainer.train()

Hyperparameter Search for Distillation parameter alpha & temperature with optuna

The parameters alpha & temperature in the DistillationTrainer can also be included in a hyperparameter search to maximize our “knowledge extraction”. As hyperparameter optimization framework we are using Optuna, which has an integration into the Trainer API. Since the DistillationTrainer is a subclass of the Trainer, we can use hyperparameter_search without any code changes.

#%pip install optuna

To do hyperparameter optimization using Optuna we need to define our hyperparameter space. In this example we search over num_train_epochs, learning_rate, alpha & temperature to maximize the evaluation score of our student_model.

def hp_space(trial):
    return {
        "num_train_epochs": trial.suggest_int("num_train_epochs", 2, 10),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "alpha": trial.suggest_float("alpha", 0, 1),
        "temperature": trial.suggest_int("temperature", 2, 30),
    }
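
To get a feeling for the configurations this search space can produce, you can call hp_space with a small stand-in object before starting a real study. The RandomTrial class below is hypothetical and only mimics the two suggest methods used here with plain uniform sampling. It does not reproduce Optuna's actual sampler, which by default is TPE and learns from earlier trials. It mainly illustrates what log=True means: the learning rate is sampled uniformly in log space, so small values are as likely as large ones.

```python
import math
import random


class RandomTrial:
    """Hypothetical stand-in for an Optuna trial that samples uniformly at random."""

    def __init__(self, seed: int):
        self.rng = random.Random(seed)
        self.params = {}

    def suggest_int(self, name, low, high):
        value = self.rng.randint(low, high)  # both bounds are inclusive
        self.params[name] = value
        return value

    def suggest_float(self, name, low, high, log=False):
        if log:
            # sample uniformly in log space, then map back
            value = math.exp(self.rng.uniform(math.log(low), math.log(high)))
        else:
            value = self.rng.uniform(low, high)
        self.params[name] = value
        return value


def hp_space(trial):
    return {
        "num_train_epochs": trial.suggest_int("num_train_epochs", 2, 10),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "alpha": trial.suggest_float("alpha", 0, 1),
        "temperature": trial.suggest_int("temperature", 2, 30),
    }


for seed in range(3):
    print(hp_space(RandomTrial(seed)))
```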

To start our hyperparameter search we just need to call hyperparameter_search and provide our hp_space and the number of trials to run.

def student_init():
    # return a fresh student for every trial
    return AutoModelForSequenceClassification.from_pretrained(
        student_id,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id
    )

trainer = DistillationTrainer(
    model_init=student_init,
    args=training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
best_run = trainer.hyperparameter_search(
    n_trials=50,
    direction="maximize",
    hp_space=hp_space
)

print(best_run)

Since Optuna only finds the best hyperparameters, we need to fine-tune our model again using the best hyperparameters from the best_run.

# overwrite initial hyperparameters with the ones from the best_run
for k, v in best_run.hyperparameters.items():
    setattr(training_args, k, v)

# Define a new repository to store our distilled model
best_model_ckpt = "tiny-bert-best"
training_args.output_dir = best_model_ckpt

We have overwritten the default hyperparameters with the ones from our best_run and can start the training now.

# Create a new Trainer with optimal parameters
optimal_trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

optimal_trainer.train()


# save best model, metrics and create model card
optimal_trainer.create_model_card(model_name=training_args.hub_model_id)
optimal_trainer.push_to_hub()

from huggingface_hub import HfApi

whoami = HfApi().whoami()
username = whoami['name']

print(f"{username}/{repo_name}")

Results & Conclusion

We were able to achieve an accuracy of 0.8337, which is a very good result for our model. Our distilled Tiny-BERT has 96% fewer parameters than the teacher bert-base and runs ~46.5x faster while preserving over 90% of BERT’s performance as measured on the SST-2 dataset.
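
As a rough sanity check of these numbers, the arithmetic can be written down in a few lines. The parameter counts below are the approximate figures commonly published for BERT-base (about 110M) and BERT-Tiny (about 4.4M); they are assumptions for this sketch and were not measured in this post. Since the teacher's accuracy is not stated here, the second function only derives the highest teacher accuracy for which "over 90% preserved" would still hold.

```python
def parameter_reduction(teacher_params: int, student_params: int) -> float:
    """Share of the teacher's parameters that the student does not have."""
    return 1 - student_params / teacher_params


def max_teacher_accuracy(student_accuracy: float, retained_share: float) -> float:
    """Highest teacher accuracy for which the student keeps `retained_share` of it."""
    return student_accuracy / retained_share


# Approximate published parameter counts (assumption, not measured in this post)
reduction = parameter_reduction(teacher_params=110_000_000, student_params=4_400_000)
print(f"parameter reduction: {reduction:.1%}")

# 0.8337 is the student accuracy reported above
bound = max_teacher_accuracy(student_accuracy=0.8337, retained_share=0.90)
print(f"'over 90%' holds for teacher accuracy up to: {bound:.4f}")
```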


Note: The FastFormers paper found that the biggest boost in performance is observed when the student has 6 or more layers. The google/bert_uncased_L-2_H-128_A-2 we used has only 2, which means that when changing our student to, e.g., distilbert-base-uncased, we should see better performance in terms of accuracy.

If you are now planning to implement and add task-specific knowledge distillation to your models, I suggest taking a look at sagemaker-distillation, which shows how to run task-specific knowledge distillation on Amazon SageMaker. For the example, I created a script derived from this notebook to make it as easy as possible for you to use. You only need to define your teacher_id, student_id, and dataset config to run task-specific knowledge distillation for text-classification.

from sagemaker.huggingface import HuggingFace

# hyperparameters, which are passed into the training job
hyperparameters={
    'teacher_id':'textattack/bert-base-uncased-SST-2',
    'student_id':'google/bert_uncased_L-2_H-128_A-2',
    'dataset_id':'glue',
    'dataset_config':'sst2',
    # distillation parameter
    'alpha': 0.5,
    'temparature': 4,
    # hpo parameter
    "run_hpo": True,
    "n_trials": 100,
}

# create the Estimator
huggingface_estimator = HuggingFace(..., hyperparameters=hyperparameters)

# start knowledge distillation training
huggingface_estimator.fit()
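
SageMaker hands the entries of the hyperparameters dictionary to the training script as command-line arguments, so the script needs to parse them. The following is a minimal sketch of how such a script could do this with argparse. It is a hypothetical illustration and not the author's actual script: it covers only some of the keys from the dictionary above, and the defaults are assumptions. Note that all values arrive as strings, which is why the boolean flag needs its own conversion.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Interpret 'True'/'False' style strings, since CLI values arrive as text."""
    return value.lower() in {"true", "1", "yes"}


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher_id", type=str)
    parser.add_argument("--student_id", type=str)
    parser.add_argument("--dataset_id", type=str)
    parser.add_argument("--dataset_config", type=str)
    parser.add_argument("--alpha", type=float, default=0.5)
    parser.add_argument("--run_hpo", type=str_to_bool, default=False)
    parser.add_argument("--n_trials", type=int, default=50)
    # ignore any arguments this sketch does not know about
    args, _ = parser.parse_known_args(argv)
    return args


args = parse_args([
    "--teacher_id", "textattack/bert-base-uncased-SST-2",
    "--student_id", "google/bert_uncased_L-2_H-128_A-2",
    "--alpha", "0.5",
    "--run_hpo", "True",
    "--n_trials", "100",
])
print(args.teacher_id, args.alpha, args.run_hpo, args.n_trials)
```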

In conclusion, it is incredible how easily Transformers and the Trainer API can be used to implement task-specific knowledge distillation. We needed to write only ~20 lines of custom code, deriving the Trainer into a DistillationTrainer, to support task-specific knowledge distillation while keeping all benefits of the Trainer API like evaluation, hyperparameter tuning, and model card creation.

In addition, we used Amazon SageMaker to easily scale our training without thinking too much about the infrastructure and how we iterate on our experiments. In the end we created an example that can be used for any text-classification dataset and teacher & student combination for task-specific knowledge distillation.

I believe this will help companies improve the production performance of their Transformers models even more by implementing task-specific knowledge distillation as one part of their MLOps pipeline.

You can find the code here and feel free to open a thread on the forum.

Thanks for reading. If you have any questions, feel free to contact me through GitHub or on the forum. You can also connect with me on Twitter or LinkedIn.