Train and Deploy open Embedding Models on Amazon SageMaker
Embedding models are crucial for successful RAG applications, but they are often trained on general knowledge, which limits their effectiveness for company- or domain-specific use cases. Customizing embedding models for your domain-specific data can significantly boost the retrieval performance of your RAG application. With the new release of Sentence Transformers 3 and the Hugging Face Embedding Container, it's easier than ever to fine-tune and deploy embedding models.
In this blog post, we'll show you how to fine-tune and deploy a custom embedding model on Amazon SageMaker using the new Hugging Face Embedding Container. We'll use the Sentence Transformers 3 library to fine-tune a model on a custom dataset and deploy it on Amazon SageMaker for inference. We will fine-tune BAAI/bge-base-en-v1.5 for financial RAG applications using a synthetic dataset from the 2023_10 NVIDIA SEC Filing.
- Setup development environment
- Create and prepare the dataset
- Fine-tune Embedding model on Amazon SageMaker
- Deploy & Test fine-tuned Embedding Model on Amazon SageMaker
Note: This blog is an extension of my Fine-tune Embedding models for Retrieval Augmented Generation (RAG) guide, a dedicated version specifically tailored to run on Amazon SageMaker.
What is new with Sentence Transformers 3?
Sentence Transformers v3 introduces a new trainer that makes it easier to fine-tune and train embedding models. This update includes enhanced components like diverse datasets, updated loss functions, and a streamlined training process, improving the efficiency and flexibility of model development.
What is the Hugging Face Embedding Container?
The Hugging Face Embedding Container is a new purpose-built inference container to easily deploy embedding models in a secure and managed environment. The DLC is powered by Text Embeddings Inference (TEI), a blazing-fast and memory-efficient solution for deploying and serving embedding models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5. TEI implements many features, such as:
- No model graph compilation step
- Small Docker images and fast boot times
- Token-based dynamic batching
- Optimized transformers code for inference using Flash Attention, Candle and cuBLAS
- Safetensors weight loading
- Production readiness (distributed tracing with OpenTelemetry, Prometheus metrics)
Note: This blog was created and validated on an ml.g5.xlarge instance for training and an ml.c6i.2xlarge instance for inference.
1. Setup Development Environment
Our first step is to install the Hugging Face libraries we need on the client to correctly prepare our dataset and start our training/evaluation jobs.
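For example, an installation step along these lines covers it (run in a notebook cell; version pins are omitted here and should be chosen to match your environment):

```python
# Install the client-side libraries used in this example (notebook cell).
# Pin versions as needed for reproducibility.
%pip install --upgrade sagemaker transformers datasets sentence-transformers
```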
If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find more about it here.
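A minimal sketch of the session and role setup; the fallback IAM role name `sagemaker_execution_role` is a placeholder you would replace with your own:

```python
import sagemaker
import boto3

sess = sagemaker.Session()

try:
    # works when running inside SageMaker (Studio / notebook instances)
    role = sagemaker.get_execution_role()
except ValueError:
    # running locally: look up an IAM role with SageMaker permissions by name (placeholder name)
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
```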
2. Create and prepare the dataset
An embedding dataset typically consists of text pairs (question, answer/context) or triplets that represent relationships or similarities between sentences. The dataset format you choose or have available will also impact the loss function you can use. Common formats for embedding datasets:
- Positive Pair: Text pairs of related sentences (query, context | query, answer), suitable for tasks like similarity or semantic search. Example datasets: sentence-transformers/sentence-compression, sentence-transformers/natural-questions.
- Triplets: Text triplets consisting of (anchor, positive, negative). Example datasets: sentence-transformers/quora-duplicates, nirantk/triplets.
- Pair with Similarity Score: Sentence pairs with a similarity score indicating how related they are. Example datasets: sentence-transformers/stsb, PhilipMay/stsb_multi_mt.
Learn more at Dataset Overview.
We are going to use philschmid/finanical-rag-embedding-dataset, which includes 7,000 positive text pairs of questions and corresponding context from the 2023_10 NVIDIA SEC Filing.
The dataset has the following format:
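A quick way to inspect it is to load the dataset and print a sample; the question and context field names follow the positive-pair description above:

```python
from datasets import load_dataset

# Load the question/context pairs from the Hugging Face Hub
dataset = load_dataset("philschmid/finanical-rag-embedding-dataset", split="train")

# Each sample is a positive pair of a question and the context that answers it, e.g.
# {"question": "...", "context": "..."}
print(dataset[0])
```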
We are going to use the FileSystem integration to upload our dataset to S3. We are using the sess.default_bucket(); adjust this if you want to store the dataset in a different S3 bucket. We will use the S3 path later in our training script.
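A sketch of the preparation and upload, assuming a 90/10 train/test split, a rename to the anchor/positive column convention, and that s3fs is installed so datasets can write directly to S3:

```python
# rename columns to the (anchor, positive) convention used by sentence-transformers
dataset = dataset.rename_column("question", "anchor")
dataset = dataset.rename_column("context", "positive")

# split into train / test sets (90/10)
dataset = dataset.train_test_split(test_size=0.1)

# S3 path built from the default session bucket; adjust the prefix as you like
input_path = f"s3://{sess.default_bucket()}/datasets/rag-embedding"

# write both splits to S3 as JSON (requires s3fs)
dataset["train"].to_json(f"{input_path}/train/dataset.json", orient="records")
dataset["test"].to_json(f"{input_path}/test/dataset.json", orient="records")

print(f"train dataset uploaded to: {input_path}/train/dataset.json")
```

The resulting input_path is reused below when starting the training job.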
3. Fine-tune Embedding model on Amazon SageMaker
We are now ready to fine-tune our model. We will use the SentenceTransformerTrainer from sentence-transformers to fine-tune our model. The SentenceTransformerTrainer makes it straightforward to supervise the fine-tuning of open embedding models, as it is a subclass of the Trainer from the transformers library. We prepared a script, run_mnr.py, which loads the dataset from disk, prepares the model and tokenizer, and starts the training (a sketch of such a script follows the list below).
The SentenceTransformerTrainer supports:
- Integrated Components: Combines datasets, loss functions, and evaluators into a unified training framework.
- Flexible Data Handling: Supports various data formats and easy integration with Hugging Face datasets.
- Versatile Loss Functions: Offers multiple loss functions for different training tasks.
- Multi-Dataset Training: Facilitates simultaneous training with multiple datasets and different loss functions.
- Seamless Integration: Easy saving, loading, and sharing of models within the Hugging Face ecosystem.
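To make this concrete, here is a minimal, illustrative sketch of what a training script in the spirit of run_mnr.py could look like; the argument names, batch size, and Matryoshka dimensions are assumptions, not the exact contents of the repository script:

```python
# Illustrative sketch of a training script in the spirit of run_mnr.py.
# Argument names, batch size, and Matryoshka dimensions are assumptions.
import argparse

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss


def main(args):
    # SageMaker mounts the S3 channels under /opt/ml/input/data/<channel>
    train_dataset = load_dataset("json", data_files=f"{args.train_dataset_path}/dataset.json", split="train")
    test_dataset = load_dataset("json", data_files=f"{args.test_dataset_path}/dataset.json", split="train")

    model = SentenceTransformer(args.model_id)

    # in-batch negatives loss, wrapped with Matryoshka loss to train multiple embedding dimensions
    base_loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[768, 512, 256, 128, 64])

    training_args = SentenceTransformerTrainingArguments(
        output_dir="/opt/ml/model",  # SageMaker uploads this directory to S3 after training
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=32,
        learning_rate=args.lr,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,  # evaluator setup (e.g. InformationRetrievalEvaluator) omitted for brevity
        loss=loss,
    )
    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="BAAI/bge-base-en-v1.5")
    parser.add_argument("--train_dataset_path", type=str, default="/opt/ml/input/data/train")
    parser.add_argument("--test_dataset_path", type=str, default="/opt/ml/input/data/test")
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=2e-5)
    main(parser.parse_args())
```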
In order to create a SageMaker training job we need a HuggingFace Estimator. The Estimator handles end-to-end Amazon SageMaker training and deployment tasks and manages the infrastructure for us. Amazon SageMaker takes care of starting and managing all the required EC2 instances, provides the correct Hugging Face container, uploads the provided scripts, and downloads the data from our S3 bucket into the container at /opt/ml/input/data. Then, it starts the training job by running the training script.
Note: Make sure that you include the requirements.txt in the source_dir if you are using a custom training script. We recommend simply cloning the whole repository.
Let's first define our training parameters. These are passed as CLI arguments to our training script. We are going to use the BAAI/bge-base-en-v1.5 model, which was pre-trained on a large corpus of English text. We will use the MultipleNegativesRankingLoss in combination with the MatryoshkaLoss. This approach allows us to leverage the efficiency and flexibility of Matryoshka embeddings, enabling different embedding dimensions to be utilized without significant performance trade-offs. The MultipleNegativesRankingLoss is a great loss function if you only have positive pairs, as it adds in-batch negative samples to the loss function so that each sample has n-1 negative samples.
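A sketch of the hyperparameters and Estimator definition; the hyperparameter names must match whatever your training script parses, and the framework/Python versions are assumptions that should be aligned with an available Hugging Face DLC:

```python
from sagemaker.huggingface import HuggingFace

# hyperparameters passed as CLI arguments to the training script
hyperparameters = {
    "model_id": "BAAI/bge-base-en-v1.5",                  # model to fine-tune
    "train_dataset_path": "/opt/ml/input/data/train/",    # where SageMaker mounts the train channel
    "test_dataset_path": "/opt/ml/input/data/test/",      # where SageMaker mounts the test channel
    "num_train_epochs": 3,
    "lr": 2e-5,
}

huggingface_estimator = HuggingFace(
    entry_point="run_mnr.py",      # training script
    source_dir="scripts",          # directory containing the script and requirements.txt (assumed layout)
    instance_type="ml.g5.xlarge",  # GPU instance used for training
    instance_count=1,
    role=role,                     # IAM role from the setup step
    transformers_version="4.36",   # versions are assumptions; use a supported DLC combination
    pytorch_version="2.1",
    py_version="py310",
    hyperparameters=hyperparameters,
)
```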
We can now start our training job with the .fit() method, passing our S3 paths as the training inputs.
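For example, using the input_path from the dataset upload step:

```python
# data input channels mapped to the S3 paths we uploaded earlier
data = {
    "train": f"{input_path}/train",
    "test": f"{input_path}/test",
}

# start the training job
huggingface_estimator.fit(data, wait=True)
```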
In our example, training BGE Base with Flash Attention 2 (SDPA) for 3 epochs on a dataset of 6.3k train samples and 700 eval samples took 645 seconds (~10 minutes) on an ml.g5.xlarge ($1.2575/h), or roughly $0.2 in total.
4. Deploy & Test fine-tuned Embedding Model on Amazon SageMaker
We are going to use the Hugging Face Embedding Container, a purpose-built inference container to easily deploy embedding models in a secure and managed environment. The DLC is powered by Text Embeddings Inference (TEI), a blazing-fast and memory-efficient solution for deploying and serving embedding models.
To retrieve the new Hugging Face Embedding Container in Amazon SageMaker, we can use the get_huggingface_llm_image_uri method provided by the sagemaker SDK. This method allows us to retrieve the URI for the desired Hugging Face Embedding Container. It is important to note that TEI has two different versions, for CPU and for GPU, so we create a helper function to retrieve the correct image URI based on the instance type.
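A helper along these lines could select the image; the huggingface-tei / huggingface-tei-cpu backend names and the instance-family check are assumptions to adapt to your SDK version:

```python
from sagemaker.huggingface import get_huggingface_llm_image_uri


def get_image_uri(instance_type: str) -> str:
    # TEI publishes separate images for GPU and CPU; pick based on the instance family
    gpu_families = ("ml.g", "ml.p")
    backend = "huggingface-tei" if instance_type.startswith(gpu_families) else "huggingface-tei-cpu"
    return get_huggingface_llm_image_uri(backend)


print(get_image_uri("ml.c6i.2xlarge"))
```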
We can now create a HuggingFaceModel using the container URI and the S3 path to our model. We also need to set our TEI configuration.
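A sketch of the model definition, assuming the fine-tuned artifact comes from the training job above and that the container loads the local model via the HF_MODEL_ID environment variable:

```python
from sagemaker.huggingface import HuggingFaceModel

instance_type = "ml.c6i.2xlarge"

emb_model = HuggingFaceModel(
    role=role,
    image_uri=get_image_uri(instance_type),       # TEI image from the helper above
    model_data=huggingface_estimator.model_data,  # S3 path to the fine-tuned model artifact
    env={"HF_MODEL_ID": "/opt/ml/model"},         # point TEI at the unpacked model inside the container
)
```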
After we have created the HuggingFaceModel we can deploy it to Amazon SageMaker using the deploy method. We will deploy the model with the ml.c6i.2xlarge instance type.
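For example:

```python
# deploy the model to a real-time endpoint
emb = emb_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
)
```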
SageMaker will now create our endpoint and deploy the model to it. This can take ~5 minutes. After our endpoint is deployed we can run inference on it. We will use the predict method from the predictor to run inference on our endpoint.
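A minimal example; the input sentence is just an illustrative query:

```python
data = {
    "inputs": "What was NVIDIA's revenue driven by in the reported quarter?",
}

res = emb.predict(data=data)

# the endpoint returns one embedding vector per input
print(f"embedding dimension: {len(res[0])}")
print(f"first 10 values: {res[0][:10]}")
```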
We trained our model with the Matryoshka loss, which means that the semantic information is front-loaded into the earlier dimensions. To use the different Matryoshka dimensions, we need to manually truncate our embeddings. Below is an example of how to truncate the embeddings to 256 dimensions, which is 1/3 of the original size. If we check our training logs, we can see that the NDCG metric for 768 dimensions is 0.823 and for 256 dimensions is 0.818, meaning we preserve more than 99% of the accuracy.
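A sketch of the truncation; the query is illustrative, and depending on your similarity metric you may want to re-normalize the truncated vector:

```python
data = {
    "inputs": "What was NVIDIA's revenue driven by in the reported quarter?",
}

res = emb.predict(data=data)

# keep only the first 256 of the 768 Matryoshka-trained dimensions
dim = 256
truncated = res[0][:dim]
print(f"truncated embedding dimension: {len(truncated)}")
# for cosine similarity, you would typically re-normalize the truncated vector
```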
Awesome! 🚀 Now you can generate embeddings and integrate the endpoint into your RAG application.
To clean up, we can delete the model and endpoint.
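For example:

```python
# delete the model and endpoint to stop incurring costs
emb.delete_model()
emb.delete_endpoint()
```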
Thanks for reading! If you have any questions or feedback, please let me know on Twitter or LinkedIn.