
Accelerate BERT inference with DeepSpeed-Inference on GPUs


In this session, you will learn how to optimize Hugging Face Transformers models (BERT, RoBERTa) for GPU inference using DeepSpeed-Inference and its state-of-the-art optimization techniques. The session focuses on single-GPU inference. We are going to optimize a BERT-large model for token classification, fine-tuned on the conll2003 dataset, and decrease its latency from 30ms to 10ms for a sequence length of 128.

You will learn how to:

1. Setup Development Environment
2. Load vanilla BERT model and set baseline
3. Optimize BERT for GPU using DeepSpeed InferenceEngine
4. Evaluate the performance and speed

Let's get started! 🚀

This tutorial was created and run on a g4dn.xlarge AWS EC2 instance with an NVIDIA T4 GPU.


Quick Intro: What is DeepSpeed-Inference

DeepSpeed-Inference is an extension of the DeepSpeed framework focused on inference workloads. It combines model-parallelism technologies, such as tensor and pipeline parallelism, with custom optimized CUDA kernels. DeepSpeed provides a seamless inference mode for compatible transformer-based models trained using DeepSpeed, Megatron, and Hugging Face; for a list of compatible models, take a look at the DeepSpeed documentation. As mentioned, DeepSpeed-Inference integrates model-parallelism techniques, which allows you to run multi-GPU inference for LLMs like BLOOM with its 176 billion parameters. If you want to learn more, check out the DeepSpeed-Inference paper and the DeepSpeed blog.
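
To give you a feel for the API before we dive into the single-GPU example, here is a minimal, hypothetical sketch of how multi-GPU (tensor-parallel) inference could look. The model id, GPU count, and prompt are placeholder assumptions, and the script would be launched with the DeepSpeed launcher, e.g. deepspeed --num_gpus 2 run_inference.py.

# run_inference.py - illustrative multi-GPU sketch, not part of this tutorial
# the model id, GPU count, and prompt are placeholders
import torch
import deepspeed
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "bigscience/bloom-7b1"

# load model and tokenizer in fp16
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# shard the model across the GPUs visible to the deepspeed launcher
ds_model = deepspeed.init_inference(
    model=model,
    mp_size=2,                       # number of GPUs to shard across
    dtype=torch.half,                # run in fp16
    replace_with_kernel_inject=True, # use DeepSpeed's optimized kernels
)

# run a short generation on the sharded model
inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(f"cuda:{torch.cuda.current_device()}")
print(tokenizer.decode(ds_model.module.generate(**inputs, max_new_tokens=20)[0]))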

1. Setup Development Environment

Our first step is to install DeepSpeed, along with PyTorch, Transformers, and some other libraries. Running the following cell will install all the required packages.

Note: You need a machine with a GPU and a compatible CUDA version installed. You can check this by running nvidia-smi in your terminal. If your setup is correct, you should get statistics about your GPU; a quick Python-based check is also shown after the install cell below.

!pip install torch==1.11.0 torchvision==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 --upgrade -q
!pip install deepspeed==0.7.0 --upgrade -q
!pip install transformers[sentencepiece]==4.21.1 --upgrade -q
!pip install datasets evaluate[evaluator]==0.2.2 seqeval --upgrade -q
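
In addition to nvidia-smi, you can run a quick sanity check from Python once PyTorch is installed. This is just an illustrative snippet; device index 0 assumes a single-GPU machine like the g4dn.xlarge used here.

import torch

# verify that PyTorch sees the GPU and report the CUDA build it was compiled against
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
print("PyTorch CUDA build:", torch.version.cuda)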

Before we start, let's make sure all packages are installed correctly.

import re
import torch

# check deepspeed installation
report = !python3 -m deepspeed.env_report
r = re.compile('.*ninja.*OKAY.*')
assert any(r.match(line) for line in report), "DeepSpeed Inference not correctly installed"

# check cuda and torch version
torch_version, cuda_version = torch.__version__.split("+")
torch_version = ".".join(torch_version.split(".")[:2])
cuda_version = f"{cuda_version[2:4]}.{cuda_version[4:]}"
r = re.compile(f'.*torch.*{torch_version}.*')
assert any(r.match(line) for line in report), "Wrong Torch version"
r = re.compile(f'.*cuda.*{cuda_version}.*')
assert any(r.match(line) for line in report), "Wrong CUDA version"

2. Load vanilla BERT model and set baseline

After setting up our environment, we create a baseline for our model. We use dslim/bert-large-NER, a BERT-large model fine-tuned on the English version of the standard CoNLL-2003 Named Entity Recognition dataset, achieving an f1 score of 95.7%.

To create our baseline, we load the model with transformers and create a token-classification pipeline.

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Model Repository on huggingface.co
model_id="dslim/bert-large-NER"

# Load Model and Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(model_id)

# Create a pipeline for token classification
token_clf = pipeline("token-classification", model=model, tokenizer=tokenizer, device=0)

# Test pipeline
example = "My name is Wolfgang and I live in Berlin"
ner_results = token_clf(example)
print(ner_results)
# [{'entity': 'B-PER', 'score': 0.9971501, 'index': 4, 'word': 'Wolfgang', 'start': 11, 'end': 19}, {'entity': 'B-LOC', 'score': 0.9986046, 'index': 9, 'word': 'Berlin', 'start': 34, 'end': 40}]

Next, we create a baseline with the evaluate library, using its evaluator and the conll2003 dataset. The evaluator allows us to evaluate a model/pipeline on a dataset using a defined metric.

from evaluate import evaluator
from datasets import load_dataset

# load eval dataset
eval_dataset = load_dataset("conll2003", split="validation")

# define evaluator
task_evaluator = evaluator("token-classification")

# run baseline
results = task_evaluator.compute(
    model_or_pipeline=token_clf,
    data=eval_dataset,
    metric="seqeval",
)

print(f"Overall f1 score for our model is {results['overall_f1']*100:.2f}%")
print(f"The avg. Latency of the model is {results['latency_in_seconds']*1000:.2f}ms")
# Overall f1 score for our model is 95.76%
# The avg. Latency of the model is 18.70ms

Our model achieves an f1 score of 95.8% on the CoNLL-2003 dataset with an average latency across the dataset of 18.7ms.

3. Optimize BERT for GPU using DeepSpeed InferenceEngine

The next and most important step is to optimize our model for GPU inference. This is done with the DeepSpeed InferenceEngine, which is initialized using the init_inference method. The init_inference method expects at least the following parameters:

  • model: The model to optimize.
  • mp_size: The number of GPUs to use.
  • dtype: The data type to use.
  • replace_with_kernel_inject: Whether to inject DeepSpeed's custom optimized kernels.

You can find more information about the init_inference method in the DeepSpeed documentation or their inference blog.

import torch
import deepspeed
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from deepspeed.module_inject import HFBertLayerPolicy

# Model Repository on huggingface.co
model_id="dslim/bert-large-NER"

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(model_id)

# init deepspeed inference engine
ds_model = deepspeed.init_inference(
    model=model,      # Transformers models
    mp_size=1,        # Number of GPU
    dtype=torch.half, # dtype of the weights (fp16)
    # injection_policy={"BertLayer" : HFBertLayerPolicy}, # replace BertLayer with DS HFBertLayerPolicy
    replace_method="auto", # Lets DS autmatically identify the layer to replace
    replace_with_kernel_inject=True, # replace the model with the kernel injector
)

# create acclerated pipeline
ds_clf = pipeline("token-classification", model=ds_model, tokenizer=tokenizer,device=0)

# Test pipeline
example = "My name is Wolfgang and I live in Berlin"
ner_results = ds_clf(example)
print(ner_results)

We can now inspect our model graph to see that the vanilla BertLayer has been replaced with DeepSpeedTransformerInference, a custom nn.Module optimized for inference by the DeepSpeed team.
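
One simple way to do this is to print the engine; the (truncated) output looks like this.

# print the optimized model graph
print(ds_model)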

InferenceEngine(
  (module): BertForTokenClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 1024, padding_idx=0)
        (position_embeddings): Embedding(512, 1024)
        (token_type_embeddings): Embedding(2, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): DeepSpeedTransformerInference(
            (attention): DeepSpeedSelfAttention()
            (mlp): DeepSpeedMLP()
          )

We can validate this with a simple assert.

from deepspeed.ops.transformer.inference import DeepSpeedTransformerInference

assert isinstance(ds_model.module.bert.encoder.layer[0], DeepSpeedTransformerInference), "Model not successfully initialized"

Now, let's run the same evaluation as for our baseline transformers model.

# evaluate the optimized model
ds_results = task_evaluator.compute(
    model_or_pipeline=ds_clf,
    data=eval_dataset,
    metric="seqeval",
)

print(f"Overall f1 score for our model is {ds_results['overall_f1']*100:.2f}%")
print(f"The avg. Latency of the model is {ds_results['latency_in_seconds']*1000:.2f}ms")

# Overall f1 score for our model is 95.64%
# The avg. Latency of the model is 9.33ms

Our DeepSpeed model achieves an f1 score of 95.6% on the CoNLL-2003 dataset with an average latency across the dataset of 9.33ms.

4. Evaluate the performance and speed

As the last step, we want to take a detailed look at the performance and accuracy of our optimized model. Applying optimization techniques, like graph optimizations or mixed precision, not only impacts latency but can also affect the accuracy of the model. So accelerating your model comes with a trade-off.

In our example, the vanilla model achieved an f1 score of 95.8% on the conll2003 evaluation dataset with an average latency of 18.7ms, while our optimized model achieved an f1 score of 95.6% with an average latency of 9.33ms.

print(f"The optimized ds-model achieves {round(ds_results['overall_f1']/results['overall_f1'],4)*100:.2f}% accuracy of the vanilla transformers model.")

The optimized ds-model achieves 99.88% accuracy of the vanilla transformers model.

Now let's test the performance (latency) of our optimized model. We will use a payload with a sequence length of 128 for the benchmark. To keep it simple, we will use a Python loop and calculate the mean, standard deviation, and p95 latency for both the vanilla model and the optimized model.

from time import perf_counter
import numpy as np

payload="Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer this email in the next 7 days. Best regards and have a nice weekend "*2

print(f'Payload sequence length is: {len(tokenizer(payload)["input_ids"])}')

def measure_latency(pipe):
    latencies = []
    # warm up
    for _ in range(10):
        _ = pipe(payload)
    # Timed run
    for _ in range(300):
        start_time = perf_counter()
        _ = pipe(payload)
        latency = perf_counter() - start_time
        latencies.append(latency)
    # Compute run statistics
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    time_p95_ms = 1000 * np.percentile(latencies,95)
    return f"P95 latency (ms) - {time_p95_ms}; Average latency (ms) - {time_avg_ms:.2f} +\- {time_std_ms:.2f};", time_p95_ms

vanilla_model=measure_latency(token_clf)
ds_opt_model=measure_latency(ds_clf)

print(f"Vanilla model: {vanilla_model[0]}")
print(f"Optimized model: {ds_opt_model[0]}")
print(f"Improvement through optimization: {round(vanilla_model[1]/ds_opt_model[1],2)}x")

# Payload sequence length is: 128
# Vanilla model: P95 latency (ms) - 30.401047450277474; Average latency (ms) - 29.68 +/- 0.54;
# Optimized model: P95 latency (ms) - 10.401162500056671; Average latency (ms) - 10.10 +/- 0.17;
# Improvement through optimization: 2.92x

We managed to accelerate the BERT-large model latency from 30.4ms to 10.4ms, or 2.92x, for a sequence length of 128.

(Figure: latency comparison of the vanilla and DeepSpeed-optimized BERT-large model)

Conclusion

We successfully optimized our BERT-large Transformers model with DeepSpeed-Inference and managed to decrease its latency from 30.4ms to 10.4ms, or 2.92x, while keeping 99.88% of the model accuracy. The results are impressive, and applying the optimization was as easy as adding one additional call to deepspeed.init_inference. That said, this isn't a plug-and-play process you can transfer to any Transformers model, task, or dataset, so make sure to check whether your model is compatible with DeepSpeed-Inference.


Thanks for reading! If you have any questions, feel free to contact me through GitHub or on the forum. You can also connect with me on Twitter or LinkedIn.