Creating document embeddings with Hugging Face's Transformers & Amazon SageMaker

March 8, 2022 | 7 minute read

Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and the Amazon SageMaker Python SDK to create a real-time inference endpoint running a Sentence Transformers model for document embeddings. Currently, the SageMaker Hugging Face Inference Toolkit supports the pipeline feature from Transformers for zero-code deployment. This means you can run compatible Hugging Face Transformers models without providing pre- and post-processing code: we only need to provide the environment variables HF_TASK and HF_MODEL_ID when creating our endpoint, and the Inference Toolkit takes care of the rest. This is a great feature if you are working with existing pipelines.
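
As a rough illustration of the zero-code path, the configuration boils down to two environment variables. The model ID and task below are example values picked for this sketch, not something this guide uses; the dictionary would typically be passed as the `env` argument of `HuggingFaceModel`.

```python
# Zero-code deployment: the Inference Toolkit reads these two variables
# and serves the matching Transformers pipeline.
# NOTE: both values are illustrative assumptions, not part of this guide.
hub = {
    "HF_MODEL_ID": "distilbert-base-uncased-finetuned-sst-2-english",  # model on the Hub
    "HF_TASK": "text-classification",  # pipeline task to serve
}

# Sketch of the intended use (needs AWS credentials, so it stays commented out):
# huggingface_model = HuggingFaceModel(env=hub, role=role, ...)
print(sorted(hub))
```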

If you want to run other tasks, such as creating document embeddings, you can provide the pre- and post-processing code yourself via a custom inference script. The Hugging Face Inference Toolkit allows the user to override the default methods of the HuggingFaceHandlerService.

The custom module can override the following methods:

  • model_fn(model_dir) overrides the default method for loading a model. The return value model will be used in the predict_fn for predictions. The input is:
    • model_dir is the path to your unzipped model.tar.gz.
  • input_fn(input_data, content_type) overrides the default method for pre-processing. The return value data will be used in predict_fn for predictions. The inputs are:
    • input_data is the raw body of your request.
    • content_type is the content type from the request header.
  • predict_fn(processed_data, model) overrides the default method for predictions. The return value predictions will be used in output_fn. The inputs are:
    • processed_data is the value returned from the input_fn method.
    • model is the value returned from the model_fn method.
  • output_fn(prediction, accept) overrides the default method for post-processing. The return value result will be the response to your request (e.g. JSON). The inputs are:
    • prediction is the result from predict_fn.
    • accept is the return accept type from the HTTP Request, e.g. application/json.
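
To make the flow between these four hooks concrete, here is a minimal, dependency-free sketch. The "model" is a stand-in dictionary and the scaling logic is invented purely for illustration; a real script would load actual weights from model_dir, and the path used below is only an example value.

```python
import json

def model_fn(model_dir):
    # Stand-in "model": a real script would load weights from model_dir.
    return {"model_dir": model_dir, "scale": 2}

def input_fn(input_data, content_type):
    # Decode the raw request body; only JSON is handled in this sketch.
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")
    return json.loads(input_data)

def predict_fn(processed_data, model):
    # Toy prediction: scale each input number by the model's factor.
    return [x * model["scale"] for x in processed_data["inputs"]]

def output_fn(prediction, accept):
    # Serialize the prediction for the response; only JSON is handled.
    if accept != "application/json":
        raise ValueError(f"Unsupported accept type: {accept}")
    return json.dumps({"outputs": prediction})

# Chain the hooks in the order a request passes through them.
model = model_fn("/opt/ml/model")  # example path, assumed for this sketch
data = input_fn('{"inputs": [1, 2, 3]}', "application/json")
response = output_fn(predict_fn(data, model), "application/json")
print(response)  # {"outputs": [2, 4, 6]}
```

Only the hooks you define replace the defaults, which is why the script later in this guide gets away with implementing just model_fn and predict_fn.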

In this example, we are going to use Sentence Transformers to create sentence embeddings by applying a mean pooling layer to the raw representation.

NOTE: You can run this demo in SageMaker Studio, on your local machine, or on SageMaker Notebook Instances.

Development Environment and Permissions


%pip install sagemaker --upgrade
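import sagemaker
from packaging import version
# compare parsed versions; a plain string comparison would misorder e.g. "2.100.0" and "2.75.0"
assert version.parse(sagemaker.__version__) >= version.parse("2.75.0")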

Install git and git-lfs

# For notebook instances (Amazon Linux)
!sudo yum update -y
!curl -s | sudo bash
!sudo yum install git-lfs git -y
# For other environments (Ubuntu)
!sudo apt-get update -y
!curl -s | sudo bash
!sudo apt-get install git-lfs git -y


If you are going to use SageMaker in a local environment (not SageMaker Studio or Notebook Instances), you need access to an IAM role with the required permissions for SageMaker. You can find more about it in the SageMaker documentation.

import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()
try:
    # works inside SageMaker Studio and Notebook Instances
    role = sagemaker.get_execution_role()
except ValueError:
    # local environment: look up the execution role by name
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

Create a custom inference script

To use custom inference code, you need to create an inference script. In our example, we are going to override model_fn to load our sentence transformer correctly and predict_fn to apply mean pooling.

We are going to use the sentence-transformers/all-MiniLM-L6-v2 model. It maps sentences and paragraphs to a 384-dimensional dense vector space and can be used for tasks like clustering or semantic search.

!mkdir code
%%writefile code/
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def model_fn(model_dir):
    # Load model and tokenizer from the unpacked model.tar.gz directory
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    return model, tokenizer
def predict_fn(data, model_and_tokenizer):
    # unpack model and tokenizer
    model, tokenizer = model_and_tokenizer
    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    # return a dictionary, which will be JSON serializable (embedding of the first input only)
    return {"vectors": sentence_embeddings[0].tolist()}
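
To make the arithmetic in mean_pooling and the normalization step easier to follow, here is a dependency-free sketch of the same idea on plain Python lists. The toy embeddings and mask are made-up values, and the helper l2_normalize is a hypothetical stand-in for F.normalize, not part of the script above.

```python
import math

def mean_pooling(token_embeddings, attention_mask):
    """Average token vectors per sentence, ignoring padded positions."""
    pooled = []
    for tokens, mask in zip(token_embeddings, attention_mask):
        dim = len(tokens[0])
        summed = [0.0] * dim
        for vector, keep in zip(tokens, mask):
            for i in range(dim):
                summed[i] += vector[i] * keep  # padded tokens contribute 0
        count = max(sum(mask), 1e-9)  # guard against division by zero
        pooled.append([value / count for value in summed])
    return pooled

def l2_normalize(vector):
    # Scale the vector to unit length (hypothetical stand-in for F.normalize).
    norm = max(math.sqrt(sum(v * v for v in vector)), 1e-12)
    return [v / norm for v in vector]

# One sentence with three token slots; the last slot is padding (mask 0).
token_embeddings = [[[1.0, 2.0], [5.0, 6.0], [9.0, 9.0]]]
attention_mask = [[1, 1, 0]]

pooled = mean_pooling(token_embeddings, attention_mask)
print(pooled)                   # [[3.0, 4.0]]
print(l2_normalize(pooled[0]))  # [0.6, 0.8]
```

The padded token [9.0, 9.0] does not influence the result, which is the reason the attention mask is passed into the pooling step.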

Create model.tar.gz with inference script and model

To use our inference script, we need to bundle it into a model.tar.gz archive together with all our model artifacts, e.g. pytorch_model.bin. The script will be placed into a code/ folder. We will use git and git-lfs to download our model from the Hugging Face Hub and upload it to Amazon S3 so we can use it when creating our SageMaker endpoint.

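repository = "sentence-transformers/all-MiniLM-L6-v2"
# model_id and s3_location are used by the commands below; the S3 key matches the upload output shown at the end
model_id = repository.split("/")[-1]
s3_location = f"s3://{sess.default_bucket()}/custom_inference/{model_id}/model.tar.gz"
  1. Download the model from the Hugging Face Hub with git clone.
!git lfs install
!git clone$repository
  2. Copy the inference script into the code/ directory of the model directory.
!cp -r code/ $model_id/code/
  3. Create a model.tar.gz archive with all the model artifacts and the inference script.
%cd $model_id
!tar zcvf model.tar.gz *
  4. Upload the model.tar.gz to Amazon S3:
!aws s3 cp model.tar.gz $s3_location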
#    upload: ./model.tar.gz to s3://sagemaker-us-east-1-558105141721/custom_inference/all-MiniLM-L6-v2/model.tar.gz
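
The archive layout matters: the model files should sit at the root of model.tar.gz, with the script inside code/. The following sketch reproduces that layout with Python's tarfile module on placeholder files so it can be inspected locally. The file names, including inference.py as the script name, are assumptions made for this illustration.

```python
import os
import tarfile
import tempfile

with tempfile.TemporaryDirectory() as workdir:
    model_dir = os.path.join(workdir, "model")
    os.makedirs(os.path.join(model_dir, "code"))

    # Placeholder files standing in for the real artifacts (assumed names).
    for relative_path in ("config.json", "pytorch_model.bin", "code/inference.py"):
        with open(os.path.join(model_dir, relative_path), "w") as handle:
            handle.write("placeholder")

    # Add the directory contents, not the directory itself, so the files
    # sit at the archive root (same effect as running tar inside the folder).
    archive_path = os.path.join(workdir, "model.tar.gz")
    with tarfile.open(archive_path, "w:gz") as archive:
        for name in sorted(os.listdir(model_dir)):
            archive.add(os.path.join(model_dir, name), arcname=name)

    # Read the archive back to check the resulting layout.
    with tarfile.open(archive_path, "r:gz") as archive:
        members = sorted(archive.getnames())

print(members)  # ['code', 'code/inference.py', 'config.json', 'pytorch_model.bin']
```

If the listing showed an extra leading folder, the archive would have been created one directory too high.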

Create a custom HuggingFaceModel

After we have created and uploaded our model.tar.gz archive to Amazon S3, we can create a custom HuggingFaceModel class. This class will be used to create and deploy our SageMaker endpoint.

from sagemaker.huggingface.model import HuggingFaceModel
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
   model_data=s3_location,       # path to your model and script
   role=role,                    # iam role with permissions to create an Endpoint
   transformers_version="4.12",  # transformers version used
   pytorch_version="1.9",        # pytorch version used
   py_version='py38',            # python version used
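)
# deploy the endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,      # assumption: the original arguments are missing from this copy
    instance_type="ml.m5.xlarge",  # assumption: pick an instance type that fits your workload
)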

Request Inference Endpoint using the HuggingFacePredictor

The .deploy() method returns a HuggingFacePredictor object, which can be used to request inference.

data = {
  "inputs": "the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .",
}
res = predictor.predict(data=data)
print(res)
#   {'vectors': [0.005078191868960857, -0.0036594511475414038, .....]}
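
Because the inference script L2-normalizes its output, the cosine similarity between two returned vectors reduces to a plain dot product, which is handy for semantic search. The sketch below uses short made-up unit vectors in place of real 384-dimensional responses; the names query and documents are hypothetical.

```python
def dot(a, b):
    # For unit-length vectors, the dot product equals the cosine similarity.
    return sum(x * y for x, y in zip(a, b))

# Toy unit-length vectors standing in for endpoint responses.
query = [0.6, 0.8]
documents = {"doc_a": [0.8, 0.6], "doc_b": [-0.6, -0.8]}

# Score every document against the query and pick the closest one.
scores = {name: dot(query, vector) for name, vector in documents.items()}
best_match = max(scores, key=scores.get)
print(best_match)  # doc_a
```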

Delete model and endpoint

To clean up, we can delete the model and endpoint.

predictor.delete_model()
predictor.delete_endpoint()

We managed to provide a custom inference script to overwrite default methods for model loading and running inference. This allowed us to use Sentence Transformers models for creating sentence embeddings with minimal code changes.

Custom inference scripts are an easy way to customize the inference pipeline of the Hugging Face Inference Toolkit when your pipeline is not represented in the pipelines API of Transformers or when you want to add custom logic.

Thanks for reading! If you have any questions, feel free to contact me through GitHub or on the forum. You can also connect with me on Twitter or LinkedIn.