
Optimize & Deploy BERT on AWS Inferentia2


In this end-to-end tutorial, you will learn how to optimize and deploy BERT on AWS Inferentia2. We will reduce the latency down to 4ms for BERT-base with a sequence length of 128.

You will learn how to:

  1. Convert BERT to AWS Neuron (Inferentia2) with optimum-neuron
  2. Create a custom inference.py script for text-classification
  3. Upload the neuron model and inference script to Amazon S3
  4. Deploy a Real-time Inference Endpoint on Amazon SageMaker
  5. Run and evaluate Inference performance of BERT on Inferentia2

Quick intro: AWS Inferentia 2

AWS Inferentia2 (Inf2) instances are purpose-built EC2 instances for deep learning (DL) inference workloads. Inferentia2 is the successor of AWS Inferentia and promises to deliver up to 4x higher throughput and up to 10x lower latency.

| instance size | accelerators | Neuron Cores | accelerator memory | vCPUs | CPU memory |
| ------------- | ------------ | ------------ | ------------------ | ----- | ---------- |
| inf2.xlarge   | 1            | 2            | 32 GB              | 4     | 16 GB      |
| inf2.8xlarge  | 1            | 2            | 32 GB              | 32    | 128 GB     |
| inf2.24xlarge | 6            | 12           | 192 GB             | 96    | 384 GB     |
| inf2.48xlarge | 12           | 24           | 384 GB             | 192   | 768 GB     |

Additionally, Inferentia2 will support the writing of custom operators in C++ and new data types, including FP8 (cFP8).

Let's get started! 🚀


If you are going to use SageMaker in a local environment (not SageMaker Studio or Notebook Instances), you need access to an IAM role with the required permissions for SageMaker. You can find more about it here.

1. Convert BERT to AWS Neuron (Inferentia2) with optimum-neuron

We are going to use optimum-neuron. 🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators, including AWS Trainium and AWS Inferentia. It provides a set of tools enabling easy model loading, training and inference on single- and multi-accelerator settings for different downstream tasks.

As a first step, we need to install the optimum-neuron and other required packages.

Tip: If you use Amazon SageMaker Notebook Instances or Studio, you can go with the conda_python3 conda kernel.

# Install the required packages
!pip install "git+https://github.com/huggingface/optimum-neuron.git@b94d534cc0160f1e199fae6ae3a1c7b804b49e30"  --upgrade

# !python -m pip install "sagemaker==2.169.0"  --upgrade
!python -m pip install "git+https://github.com/aws/sagemaker-python-sdk.git"  --upgrade

After we have installed optimum-neuron, we can load and convert our model. We are going to use the yiyanghkust/finbert-tone model. FinBERT is a BERT model pre-trained on financial communication text, with the goal of enhancing financial NLP research and practice. It was trained on three types of financial communication corpora with a total size of 4.9B tokens. The released finbert-tone model is the FinBERT model fine-tuned on 10,000 manually annotated (positive, negative, neutral) sentences from analyst reports.

model_id = "yiyanghkust/finbert-tone"

At the time of writing, AWS Inferentia2 does not support dynamic shapes for inference, which means that the input size needs to be static for compiling and inference.

In simpler terms, this means that if the model is compiled with a sequence length of 16, it can only run inference on inputs with that same shape. We will use the optimum-cli to convert our model with a sequence length of 128 and a batch size of 1.
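
To make the static-shape requirement concrete, here is a minimal sketch (not part of the original notebook) showing how inputs are padded and truncated to the compiled sequence length on the client side:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
encoded = tokenizer(
    "growth is strong and we have plenty of liquidity",
    padding="max_length",  # always pad to the compiled length
    truncation=True,       # cut off anything longer
    max_length=128,        # must match the sequence length used for compilation
    return_tensors="pt",
)
print(encoded["input_ids"].shape)  # torch.Size([1, 128]) -> matches the static input shape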

When using a t2.medium instance, compiling takes around 2-3 minutes.

%%bash -s "$model_id"
MODEL_ID=$1
SEQUENCE_LENGTH=128
BATCH_SIZE=1
OUTPUT_DIR=tmp/ # used to store temporary files
echo "Model ID: $MODEL_ID"

# exporting model
optimum-cli export neuron \
  --model $MODEL_ID \
  --sequence_length $SEQUENCE_LENGTH \
  --batch_size $BATCH_SIZE \
  $OUTPUT_DIR
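
As an optional sanity check, you could load the exported model with optimum-neuron and run a local prediction before packaging it. This is a sketch and not part of the original notebook; it assumes the NeuronModelForSequenceClassification class from optimum-neuron and that you are running on an Inferentia2 instance where the Neuron runtime is available:

from optimum.neuron import NeuronModelForSequenceClassification
from transformers import AutoTokenizer

# load the compiled model from the export directory and the tokenizer from the hub
model = NeuronModelForSequenceClassification.from_pretrained("tmp/")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# inputs must be padded to the static sequence length of 128
inputs = tokenizer(
    "there is a shortage of capital, and we need extra financing",
    padding="max_length", max_length=128, truncation=True, return_tensors="pt",
)
logits = model(**inputs).logits
print(model.config.id2label[int(logits.argmax())])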

2. Create a custom inference.py script for text-classification

The Hugging Face Inference Toolkit supports zero-code deployments on top of the pipeline feature from 🤗 Transformers. This allows users to deploy Hugging Face transformers without an inference script [Example].

Currently, this feature is not supported with AWS Inferentia2, which means we need to provide an inference.py script for running inference. But optimum-neuron has integrated support for the 🤗 Transformers pipeline feature. That way, we can use optimum-neuron to create a pipeline for our model.

If you want to know more about the inference.py script check out this example. It explains, amongst other things, what the model_fn and predict_fn are.

!mkdir code

In addition to our inference.py script we need to provide a requirements.txt that installs the latest version of the optimum-neuron package, which comes with pipeline support for AWS Inferentia2. Note: This is a temporary solution until the optimum-neuron package is updated inside the DLC.

%%writefile code/requirements.txt
git+https://github.com/huggingface/optimum-neuron.git@b94d534cc0160f1e199fae6ae3a1c7b804b49e30

We use NEURON_RT_NUM_CORES=1 to ensure that each HTTP worker uses 1 Neuron core to maximize throughput.

%%writefile code/inference.py
import os
# To use one neuron core per worker
os.environ["NEURON_RT_NUM_CORES"] = "1"
from optimum.neuron.pipelines import pipeline
import torch
import torch_neuronx

def model_fn(model_dir):
    # load local converted model into pipeline
    pipe = pipeline("text-classification", model=model_dir)
    return pipe

def predict_fn(data, pipeline):
    inputs = data.pop("inputs", data)
    parameters = data.pop("parameters", None)

    # pass inputs with all kwargs in data
    if parameters is not None:
        prediction = pipeline(inputs, **parameters)
    else:
        prediction = pipeline(inputs)
    # postprocess the prediction
    return prediction

3. Upload the neuron model and inference script to Amazon S3

Before we can deploy our neuron model to Amazon SageMaker, we need to create a model.tar.gz archive containing all our model artifacts, e.g. model.neuron, and upload it to Amazon S3.

To do this we need to set up our permissions. Currently, inf2 instances are only available in the us-east-2 region [REF]. Therefore we need to force the region to us-east-2.

import os

os.environ["AWS_DEFAULT_REGION"] = "us-east-2" # need to set to ohio region

Now let's create our SageMaker session and upload our model to Amazon S3.

import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
assert sess.boto_region_name == "us-east-2", "region must be us-east-2"

Next, we create our model.tar.gz. The inference.py script will be placed into a code/ folder.

# copy inference.py into the code/ directory of the model directory.
!cp -r code/ tmp/code/
# create a model.tar.gz archive with all the model artifacts and the inference.py script.
%cd tmp
!tar zcvf model.tar.gz *
%cd ..

Now we can upload our model.tar.gz to our session S3 bucket with sagemaker.

from sagemaker.s3 import S3Uploader

# create s3 uri
s3_model_path = f"s3://{sess.default_bucket()}/neuronx/{model_id}"

# upload model.tar.gz
s3_model_uri = S3Uploader.upload(local_path="tmp/model.tar.gz",desired_s3_uri=s3_model_path)
print(f"model artifcats uploaded to {s3_model_uri}")

4. Deploy a Real-time Inference Endpoint on Amazon SageMaker

After we have uploaded our model.tar.gz to Amazon S3, we can create a custom HuggingFaceModel. This class will be used to create and deploy our real-time inference endpoint on Amazon SageMaker.

The inf2.xlarge instance type is the smallest instance type with AWS Inferentia2 support. It comes with 1 Inferentia2 chip with 2 Neuron cores. This means we can use 2 model server workers to maximize throughput and run 2 inferences in parallel.

from sagemaker.huggingface.model import HuggingFaceModel

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
   model_data=s3_model_uri,        # path to your model and script
   role=role,                      # iam role with permissions to create an Endpoint
   transformers_version="4.28.1",  # transformers version used
   pytorch_version="1.13.0",       # pytorch version used
   py_version='py38',              # python version used
   model_server_workers=2,         # number of workers for the model server
)

# Let SageMaker know that we've already compiled the model
huggingface_model._is_compiled_model = True

# deploy the endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,      # number of instances
    instance_type="ml.inf2.xlarge" # AWS Inferentia Instance
)

5. Run and evaluate Inference performance of BERT on Inferentia2

The .deploy() method returns a HuggingFacePredictor object which can be used to request inference.

data = {
  "inputs": "the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .",
}

res = predictor.predict(data=data)
res
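
Because our predict_fn forwards everything under "parameters" as keyword arguments to the pipeline, we can also pass pipeline options. For example, here is a sketch that requests the scores for all classes (assuming the installed transformers version supports the top_k argument of the text-classification pipeline):

data = {
  "inputs": "the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .",
  "parameters": {"top_k": 3},  # return the score for every label instead of only the best one
}

res = predictor.predict(data=data)
res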

We managed to deploy our neuron-compiled BERT to AWS Inferentia2 on Amazon SageMaker. Now, let's test its performance. As a dummy load test, we will use threading to send 10,000 requests to our endpoint with 10 threads.

Note: When running the load test, our environment was based in Europe, while the endpoint is deployed in us-east-2.

import threading

number_of_threads = 10
number_of_requests = int(10000 // number_of_threads)
print(f"number of threads: {number_of_threads}")
print(f"number of requests per thread: {number_of_requests}")

def send_requests():
    for _ in range(number_of_requests):
        predictor.predict(data={"inputs": "it 's a charming and often affecting journey ."})
    print("done")

# Create multiple threads
threads = [threading.Thread(target=send_requests) for _ in range(number_of_threads)]
# start all threads
[t.start() for t in threads]
# wait for all threads to finish
[t.join() for t in threads]
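
To measure client-side throughput, you could re-run the load test wrapped in a simple timer. This is a sketch, not part of the original notebook; it reuses send_requests and the counters defined above:

import time

start = time.perf_counter()
threads = [threading.Thread(target=send_requests) for _ in range(number_of_threads)]
[t.start() for t in threads]
[t.join() for t in threads]
duration = time.perf_counter() - start

total_requests = number_of_threads * number_of_requests
print(f"total time: {duration:.1f}s")
print(f"client-side throughput: {total_requests / duration:.1f} requests/s")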

Sending 10,000 requests with 10 threads takes around 86 seconds, which means we can run ~116 inferences per second. But keep in mind that this includes the network latency from Europe to us-east-2. When we inspect the latency of the endpoint through CloudWatch, we can see that the average latency is around 4ms. This means we can run around 500 inferences per second without the network and framework overhead.

print(f"<https://console.aws.amazon.com/cloudwatch/home?region={sess.boto_region_name}#metricsV2:graph=~(metrics~(~(~'AWS*2fSageMaker~'ModelLatency~'EndpointName~'{predictor.endpoint_name}~'VariantName~'AllTraffic)>)~view~'timeSeries~stacked~false~region~'{sess.boto_region_name}~start~'-PT5M~end~'P0D~stat~'Average~period~30);query=~'*7bAWS*2fSageMaker*2cEndpointName*2cVariantName*7d*20{predictor.endpoint_name}")

The average latency for our BERT model is 3.8-4.1ms for a sequence length of 128.

Figure: CloudWatch ModelLatency for the deployed endpoint (performance.png)

To clean up, we can delete the model and endpoint.

predictor.delete_model()
predictor.delete_endpoint()

Let's take a closer look at price performance. We are using an inf2.xlarge instance, which costs $0.99/h. If we assume a latency of 4ms for BERT-base with a sequence length of 128, a single inference request costs about $0.0000011 per Neuron core, and with both cores serving requests in parallel, 1 million requests cost around $0.55.
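
As a quick back-of-the-envelope check of these numbers (a sketch; the price and latency assumptions are taken from the text above):

price_per_hour = 0.99     # USD per hour for inf2.xlarge
latency_seconds = 0.004   # ~4ms per request
neuron_cores = 2          # two workers serving requests in parallel

price_per_second = price_per_hour / 3600
cost_per_request = price_per_second * latency_seconds   # one core serving one request
requests_per_second = neuron_cores / latency_seconds    # ~500 requests/s with both cores
cost_per_million = price_per_second / requests_per_second * 1_000_000

print(f"cost per request:     ${cost_per_request:.7f}")          # ~$0.0000011
print(f"throughput:           {requests_per_second:.0f} requests/s")
print(f"cost per 1M requests: ${cost_per_million:.2f}")          # ~$0.55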

Conclusion

In this tutorial, we have shown how easy it is to optimize and deploy BERT on AWS Inferentia2 using the optimum-neuron package. We have reduced the latency down to 4ms for BERT-base with a sequence length of 128.

The AWS Inferentia2 instances provide amazing cost performance. If you are currently running BERT-like models in production on GPU instances, I recommend switching to Inferentia2. If you are interested in exploring this in more detail, let me know.


Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.