Accelerated document embeddings with Hugging Face Transformers and AWS Inferentia
Photo by Max Bender on Unsplash
notebook: sentence-transformers-huggingface-inferentia
The adoption of BERT and Transformers continues to grow. Transformer-based models are now achieving state-of-the-art performance not only in Natural Language Processing but also in Computer Vision, Speech, Time Series and especially Semantic Search. 💬 🖼 🎤 ⏳
Semantic search seeks to improve search accuracy by understanding the content of the search query. The idea behind semantic search is to embed all documents into a vector space. At search time, the query is embedded into the same vector space and the closest embeddings are found. Transformers have taken over this domain and are currently achieving state-of-the-art performance, but they are often slow, and search has to be fast to feel natural.
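To make that retrieval step concrete, here is a minimal, illustrative sketch (not tied to any particular framework): it assumes you already have document embeddings and a query embedding, for example from the 384-dimensional model we deploy below, and ranks documents by cosine similarity.

```python
import numpy as np

def cosine_similarity(query_vec, doc_matrix):
    # normalize both sides so the dot product equals cosine similarity
    query = query_vec / np.linalg.norm(query_vec)
    docs = doc_matrix / np.linalg.norm(doc_matrix, axis=1, keepdims=True)
    return docs @ query

# toy data: random vectors standing in for real sentence embeddings
doc_embeddings = np.random.rand(1000, 384)   # 1000 documents, 384-dim vectors
query_embedding = np.random.rand(384)

scores = cosine_similarity(query_embedding, doc_embeddings)
top_k = np.argsort(scores)[::-1][:5]         # indices of the 5 closest documents
print(top_k, scores[top_k])
```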
AWS’s answer to this performance challenge is AWS Inferentia, a custom machine learning chip designed for optimized inference workloads. AWS says that AWS Inferentia “delivers up to 80% lower cost per inference and up to 2.3X higher throughput than comparable current generation GPU-based Amazon EC2 instances.”
The real value of AWS Inferentia instances compared to GPU comes through the multiple Neuron Cores available on each device. A Neuron Core is the custom accelerator inside AWS Inferentia. Each Inferentia chip comes with 4x Neuron Cores. This enables you to either load 1 model on each core (for high throughput) or 1 model across all cores (for lower latency).
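As a rough sketch of the one-model-per-core pattern (using the NEURON_RT_NUM_CORES runtime variable that also appears in the inference script later in this post; the model path here is only illustrative), each worker process limits itself to a single NeuronCore and loads its own copy of the traced model:

```python
import os
import torch
import torch.neuron  # registers the Neuron runtime needed to run traced models

# limit this worker process to a single NeuronCore, so that four workers
# on an inf1.xlarge can each serve their own copy of the model in parallel
os.environ["NEURON_RT_NUM_CORES"] = "1"

# illustrative path to a model compiled as shown in step 1 below
model = torch.jit.load("neuron_model.pt")
```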
In this end-to-end tutorial, you will learn how to speed up Sentence-Transformers like SBERT for creating sentence embeddings using Hugging Face Transformers, Amazon SageMaker, and AWS Inferentia to achieve sub-5ms latency and up to 1,000 requests per second per instance.
You will learn how to:
- 1. Convert your Hugging Face sentence transformers to AWS Neuron (Inferentia)
- 2. Create a custom inference.py script for sentence-embeddings
- 3. Create and upload the neuron model and inference script to Amazon S3
- 4. Deploy a Real-time Inference Endpoint on Amazon SageMaker
- 5. Run and evaluate Inference performance of BERT on Inferentia
Let’s get started! 🚀
If you are going to use SageMaker in a local environment (not SageMaker Studio or Notebook Instances), you need access to an IAM Role with the required permissions for SageMaker. You can find more about it here.
1. Convert your Hugging Face Sentence Transformers to AWS Neuron
We are going to use the AWS Neuron SDK for AWS Inferentia. The Neuron SDK includes a deep learning compiler, runtime, and tools for converting and compiling PyTorch and TensorFlow models to Neuron-compatible models, which can be run on EC2 Inf1 instances.
As a first step, we need to install the Neuron SDK and the required packages.
Tip: If you are using Amazon SageMaker Notebook Instances or Studio, you can use the conda_python3 conda kernel.
```python
# Set Pip repository to point to the Neuron repository
!pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com

# Install Neuron PyTorch
!pip install torch-neuron==1.9.1.* neuron-cc[tensorflow] sagemaker>=2.79.0 transformers==4.12.3 --upgrade
```
After we have installed the Neuron SDK, we can load and convert our model. Neuron models are converted using torch_neuron with its trace method, similar to torchscript. You can find more information in our documentation.
We are going to use the sentence-transformers/all-MiniLM-L6-v2 model. It maps sentences & paragraphs to a 384-dimensional dense vector space and can be used for tasks like clustering or semantic search.
```python
model_id = "sentence-transformers/all-MiniLM-L6-v2"
```
At the time of writing, the AWS Neuron SDK does not support dynamic shapes, which means that the input size needs to be static for compiling and inference.
In simpler terms, this means that if the model is compiled with an input of batch size 1 and sequence length 16, it can only run inference on inputs with that same shape.
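As a quick illustration of that constraint (a small sketch using only the model's tokenizer, nothing Inferentia-specific), padding and truncating to max_length guarantees that every input, short or long, arrives with the exact shape the model was traced with:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# one very short and one very long input...
short_inputs = tokenizer("hi", max_length=128, padding="max_length", truncation=True, return_tensors="pt")
long_inputs = tokenizer("a much longer paragraph " * 40, max_length=128, padding="max_length", truncation=True, return_tensors="pt")

# ...both end up with the fixed shape the compiled model expects
print(short_inputs["input_ids"].shape, long_inputs["input_ids"].shape)  # torch.Size([1, 128]) twice
```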
When using a t2.medium instance, compiling takes around 2-3 minutes.
```python
import os
import tensorflow  # to workaround a protobuf version conflict issue
import torch
import torch.neuron
from transformers import AutoTokenizer, AutoModel

# load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, torchscript=True)

# create dummy input for max length 128
dummy_input = "dummy input which will be padded later"
max_length = 128
embeddings = tokenizer(dummy_input, max_length=max_length, padding="max_length", return_tensors="pt")
neuron_inputs = tuple(embeddings.values())

# compile model with torch.neuron.trace and update config
model_neuron = torch.neuron.trace(model, neuron_inputs)
model.config.update({"traced_sequence_length": max_length})

# save tokenizer, neuron model and config for later use
save_dir = "tmp"
os.makedirs("tmp", exist_ok=True)
model_neuron.save(os.path.join(save_dir, "neuron_model.pt"))
tokenizer.save_pretrained(save_dir)
model.config.save_pretrained(save_dir)
```
2. Create a custom inference.py script for sentence-embeddings
The Hugging Face Inference Toolkit supports zero-code deployments on top of the pipeline feature from 🤗 Transformers. This allows users to deploy Hugging Face transformers without an inference script [Example].
Currently, this feature is not supported with AWS Inferentia, which means we need to provide an inference.py script for running inference.
If you are interested in support for zero-code deployments for Inferentia, let us know on the forum.
To use the inference script, we need to create an inference.py script. In our example, we are going to overwrite the model_fn to load our neuron model and the predict_fn to create a sentence-embeddings pipeline.
If you want to know more about the inference.py script, check out this example. It explains, amongst other things, what the model_fn and predict_fn are.
```python
!mkdir code
```
We are using NEURON_RT_NUM_CORES=1 to make sure that each HTTP worker uses 1 Neuron core to maximize throughput.
```python
%%writefile code/inference.py

import os
from transformers import AutoConfig, AutoTokenizer
import torch
import torch.neuron
import torch.nn.functional as F

# To use one neuron core per worker
os.environ["NEURON_RT_NUM_CORES"] = "1"

# saved weights name
AWS_NEURON_TRACED_WEIGHTS_NAME = "neuron_model.pt"

# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
    # load tokenizer and neuron model from model_dir
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = torch.jit.load(os.path.join(model_dir, AWS_NEURON_TRACED_WEIGHTS_NAME))
    model_config = AutoConfig.from_pretrained(model_dir)

    return model, tokenizer, model_config


def predict_fn(data, model_tokenizer_model_config):
    # destruct model, tokenizer and model config
    model, tokenizer, model_config = model_tokenizer_model_config

    # Tokenize sentences
    inputs = data.pop("inputs", data)
    encoded_input = tokenizer(
        inputs,
        return_tensors="pt",
        max_length=model_config.traced_sequence_length,
        padding="max_length",
        truncation=True,
    )
    # convert to tuple for neuron model
    neuron_inputs = tuple(encoded_input.values())

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(*neuron_inputs)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # return dictionary, which will be json serializable
    return {"vectors": sentence_embeddings[0].tolist()}
```
Overwriting code/inference.py
3. Create and upload the neuron model and inference script to Amazon S3
Before we can deploy our neuron model to Amazon SageMaker, we need to create a model.tar.gz archive with all our model artifacts saved into tmp/, e.g. neuron_model.pt, and upload this to Amazon S3.
To do this we need to set up our permissions.
```python
import sagemaker
import boto3

sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
```
Next, we create our model.tar.gz. The inference.py script will be placed into a code/ folder.
```python
# copy inference.py into the code/ directory of the model directory.
!cp -r code/ tmp/code/
# create a model.tar.gz archive with all the model artifacts and the inference.py script.
%cd tmp
!tar zcvf model.tar.gz *
%cd ..
```
Now we can upload our model.tar.gz to our session S3 bucket with sagemaker.
```python
from sagemaker.s3 import S3Uploader

# create s3 uri
s3_model_path = f"s3://{sess.default_bucket()}/neuron/{model_id}"

# upload model.tar.gz
s3_model_uri = S3Uploader.upload(local_path="tmp/model.tar.gz", desired_s3_uri=s3_model_path)
print(f"model artifacts uploaded to {s3_model_uri}")
```
4. Deploy a Real-time Inference Endpoint on Amazon SageMaker
After we have uploaded our model.tar.gz to Amazon S3, we can create a custom HuggingFaceModel. This class will be used to create and deploy our real-time inference endpoint on Amazon SageMaker.
```python
from sagemaker.huggingface.model import HuggingFaceModel

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data=s3_model_uri,       # path to your model and script
    role=role,                     # iam role with permissions to create an Endpoint
    transformers_version="4.12",   # transformers version used
    pytorch_version="1.9",         # pytorch version used
    py_version='py37',             # python version used
)

# Let SageMaker know that we've already compiled the model via neuron-cc
huggingface_model._is_compiled_model = True

# deploy the endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,      # number of instances
    instance_type="ml.inf1.xlarge" # AWS Inferentia Instance
)
```
5. Run and evaluate Inference performance of BERT on Inferentia
The .deploy() call returns a HuggingFacePredictor object which can be used to request inference.
```python
data = {
    "inputs": "the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .",
}

res = predictor.predict(data=data)
res
```
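Because the inference script L2-normalizes the embeddings and returns them as a plain {"vectors": [...]} list, you can already use the endpoint for the semantic-search pattern from the introduction. A small illustrative sketch (the example sentences are made up):

```python
import numpy as np

# embed two documents and a query via the endpoint
sentences = [
    "the mesmerizing performances of the leads keep the film grounded .",
    "the stock market dropped sharply on monday .",
]
doc_vectors = np.array([predictor.predict(data={"inputs": s})["vectors"] for s in sentences])
query_vector = np.array(predictor.predict(data={"inputs": "a gripping, well acted movie"})["vectors"])

# embeddings are already L2-normalized by the inference script,
# so the dot product equals the cosine similarity
scores = doc_vectors @ query_vector
print(scores)  # the first (movie-related) document should score higher
```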
We managed to deploy our neuron-compiled BERT to AWS Inferentia on Amazon SageMaker. Now, let’s test its performance. As a dummy load test, we will loop and send 10,000 synchronous requests to our endpoint.
```python
# send 10000 requests
for i in range(10000):
    resp = predictor.predict(
        data={"inputs": "it 's a charming and often affecting journey ."}
    )
```
Let’s inspect the performance in CloudWatch.
```python
print(f"https://console.aws.amazon.com/cloudwatch/home?region={sess.boto_region_name}#metricsV2:graph=~(metrics~(~(~'AWS*2fSageMaker~'ModelLatency~'EndpointName~'{predictor.endpoint_name}~'VariantName~'AllTraffic))~view~'timeSeries~stacked~false~region~'{sess.boto_region_name}~start~'-PT5M~end~'P0D~stat~'Average~period~30);query=~'*7bAWS*2fSageMaker*2cEndpointName*2cVariantName*7d*20{predictor.endpoint_name}")
```
The average latency for our MiniLM model is 3-4.5ms for a sequence length of 128.
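If you want a quick client-side sanity check in addition to CloudWatch, a rough sketch like the one below times requests directly from the notebook; note that this measures end-to-end latency including network overhead, so the numbers will be slightly higher than the ModelLatency metric.

```python
import time
import numpy as np

payload = {"inputs": "it 's a charming and often affecting journey ."}
latencies = []

# time a few hundred synchronous requests from the client side
for _ in range(300):
    start = time.perf_counter()
    predictor.predict(data=payload)
    latencies.append((time.perf_counter() - start) * 1000)  # convert to ms

print(f"p50: {np.percentile(latencies, 50):.1f}ms, p95: {np.percentile(latencies, 95):.1f}ms")
```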
Delete model and endpoint
To clean up, we can delete the model and endpoint.
```python
predictor.delete_model()
predictor.delete_endpoint()
```
Conclusion
We successfully managed to compile a Sentence Transformer to an AWS Inferentia-compatible Neuron model. After that, we deployed our Neuron model to Amazon SageMaker using the new Hugging Face Inference DLC. We achieved 3-4.5ms latency per Neuron core, which is faster than a CPU in terms of latency, and a higher throughput than GPUs since we ran 4 models in parallel. We can achieve a throughput of up to 1,000 documents per second with 4ms latency at a sequence length of 128 per inf1.xlarge instance, which costs around ~$200 per month.
If you or your company are currently using Sentence Transformers for semantic search tasks (document embeddings, sentence embeddings, ranking), and 3-4.5ms latency meets your requirements, you should switch to AWS Inferentia. This will not only save costs, but can also increase efficiency and performance for your models.
We are planning to do a more detailed case study on cost-performance of transformers in the future, so stay tuned!
If you want to learn more about accelerating transformers, you should also check out Hugging Face Optimum.
Thanks for reading! If you have any questions, feel free to contact me, through Github, or on the forum. You can also connect with me on Twitter or LinkedIn.