Outperform OpenAI GPT-3 with SetFit for text-classification
In many Machine Learning applications, the amount of available labeled data is a barrier to producing a high-performing model. Developments over the last 2 years have shown that you can overcome this data limitation by using Large Language Models, like OpenAI GPT-3, together with a few examples as prompts at inference time to achieve good results. These developments improve the missing-labeled-data situation, but they introduce a new problem: the access to and cost of Large Language Models.
But a research group led by Intel Labs and the UKP Lab, together with Hugging Face, released a new approach, called "SetFit" (https://arxiv.org/abs/2209.11055), that can be used to create highly accurate text-classification models with limited labeled data. SetFit outperforms GPT-3 on 7 out of 11 tasks, while being 1600x smaller.
In this blog, you will learn how to use SetFit to create a text-classification model with only 8 labeled samples per class, or 32 samples in total. You will also learn how to improve your model by using hyperparameter tuning.
You will learn how to:
- Setup Development Environment
- Create Dataset
- Fine-Tune Classifier with SetFit
- Use Hyperparameter search to optimize results
Why SetFit is better
Compared to other few-shot learning methods, SetFit has several unique features:
- 🗣 No prompts or verbalisers: Current techniques for few-shot fine-tuning require handcrafted prompts. SetFit dispenses with prompts altogether by generating rich embeddings directly from text examples.
- 🏎 Fast to train: SetFit doesn't require large-scale models like T0 or GPT-3 to achieve high accuracy.
- 🌎 Multilingual support: SetFit can be used with any Sentence Transformer on the Hub.
Now that we know why SetFit is amazing, let's get started. 🚀
Note: This tutorial was created and run on a g4dn.xlarge AWS EC2 instance with an NVIDIA T4 GPU.
1. Setup Development Environment
Our first step is to install the Hugging Face Libraries, including SetFit. Running the following cell will install all the required packages.
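A minimal install cell could look like the sketch below; the exact versions depend on when you run it, and the setfit[optuna] extra is assumed here because it pulls in optuna for the hyperparameter search in step 4.

```python
# Install Hugging Face libraries; the "[optuna]" extra pulls in optuna,
# which is used for the hyperparameter search in step 4.
%pip install "setfit[optuna]" datasets --upgrade
```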
2. Create Dataset
We are going to use the ag_news dataset, which is a news article classification dataset with 4 classes: World (0), Sports (1), Business (2), Sci/Tech (3).
The test split of the dataset contains 7,600 examples, which will be used to evaluate our model. The train split contains 120,000 examples, which is a nice amount of data for fine-tuning a regular model.
But to showcase SetFit, we want to create a dataset with only 8 labeled samples per class, or 32 data points in total.
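A sketch of how this sampling could look with 🤗 Datasets; the seed and variable names here are our own choices, not prescribed values:

```python
from datasets import load_dataset, concatenate_datasets

# Load the ag_news dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

seed = 20
num_labels = 4
samples_per_label = 8

# Sample 8 examples for each of the 4 classes
sampled_datasets = []
for label in range(num_labels):
    sampled_datasets.append(
        dataset["train"]
        .filter(lambda example: example["label"] == label)
        .shuffle(seed=seed)
        .select(range(samples_per_label))
    )

# 32 labeled samples in total
train_dataset = concatenate_datasets(sampled_datasets)

# The full test split (7,600 examples) is used for evaluation
test_dataset = dataset["test"]
```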
3. Fine-Tune Classifier with SetFit
When using SetFit, we first fine-tune a Sentence Transformer model using our labeled data and contrastive training, where positive and negative pairs are created by in-class and out-class selection. In the second step, a classification head is trained on the encoded embeddings with their respective class labels.
As the Sentence Transformer, we are going to use sentence-transformers/all-mpnet-base-v2 (you could replace the model with any available Sentence Transformer on hf.co).
The Python setfit package implements useful classes and functions to make the fine-tuning process straightforward and easy. Similar to the Hugging Face Trainer class, SetFit implements the SetFitTrainer class, which is responsible for the training loop.
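A minimal training sketch, assuming the setfit v0.x SetFitTrainer API; the batch size, number of epochs, and num_iterations below are example values, not prescribed settings:

```python
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss

# Load a Sentence Transformer from the Hub as the SetFit body
model_id = "sentence-transformers/all-mpnet-base-v2"
model = SetFitModel.from_pretrained(model_id)

# train_dataset / test_dataset come from step 2
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=64,
    num_iterations=20,  # number of text pairs generated for contrastive learning
    num_epochs=1,       # epochs of contrastive fine-tuning
)

# Fine-tune the body contrastively, fit the head, then evaluate on the test split
trainer.train()
metrics = trainer.evaluate()
print(metrics)
```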
4. Use Hyperparameter search to optimize results
The SetFitTrainer provides a hyperparameter_search() method that you can use to find the perfect hyperparameters for the data. SetFit leverages optuna under the hood to perform the hyperparameter search. To use the hyperparameter search, we need to define a model_init method, which creates our model for every "run", and a hp_space method that defines the hyperparameter search space.
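A sketch of both methods and the search call, again assuming the setfit v0.x API; the search ranges are illustrative choices:

```python
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss

# Called at the start of every trial so each run starts from a fresh model.
# The logistic-regression head parameters come from the sampled trial params.
def model_init(params):
    params = params or {}
    head_params = {
        "max_iter": params.get("max_iter", 100),
        "solver": params.get("solver", "liblinear"),
    }
    model_id = params.get("model_id", "sentence-transformers/all-mpnet-base-v2")
    return SetFitModel.from_pretrained(model_id, head_params=head_params)

# Defines the search space optuna samples from on each trial
def hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_epochs": trial.suggest_int("num_epochs", 1, 5),
        "batch_size": trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]),
        "num_iterations": trial.suggest_categorical("num_iterations", [5, 10, 20, 40, 80]),
        "seed": trial.suggest_int("seed", 1, 40),
        "max_iter": trial.suggest_int("max_iter", 50, 300),
        "solver": trial.suggest_categorical("solver", ["newton-cg", "lbfgs", "liblinear"]),
        # more Sentence Transformers could be listed here to also search over models
        "model_id": trial.suggest_categorical(
            "model_id", ["sentence-transformers/all-mpnet-base-v2"]
        ),
    }

# train_dataset / test_dataset come from step 2
trainer = SetFitTrainer(
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    model_init=model_init,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
)

best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=100)
print(best_run.hyperparameters)
```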
After running 100 trials (runs), the best model was found with the following hyperparameters:
{'learning_rate': 2.2041595048800003e-05, 'num_epochs': 2, 'batch_size': 64, 'num_iterations': 20, 'seed': 34, 'max_iter': 182, 'solver': 'lbfgs', 'model_id': 'sentence-transformers/all-mpnet-base-v2'}
It achieved an accuracy of 0.873421052631579, which is 1.1% better than the model we trained without hyperparameter search.
After we have found the best hyperparameters, we need to run a final training using them.
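A sketch of that final step, assuming the setfit v0.x SetFitTrainer.apply_hyperparameters() helper:

```python
# best_run comes from the hyperparameter search above; apply its
# hyperparameters to the trainer and re-train the final model from scratch.
trainer.apply_hyperparameters(best_run.hyperparameters, final_model=True)
trainer.train()

metrics = trainer.evaluate()
print(metrics)
```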
Conclusion
That's it! We have created a high-performing text-classification model with only 32 labeled samples, or 8 samples per class, using the SetFit approach. Our SetFit classifier achieved an accuracy of 0.873421052631579 on the test set. For comparison, a regular model fine-tuned on the whole dataset (120,000 samples) achieves ~94% accuracy.
This means that with 3,750x less data, you lose only ~7% accuracy. 🤯
This is huge! SetFit will help so many companies get started with text-classification and transformers, without the need to label a lot of data or use a lot of compute power. Compared to LLMs, a SetFit classifier takes less than 1 hour to train on a small GPU (NVIDIA T4), or less than $1, so to speak.
Thanks for reading! If you have any questions, feel free to contact me through GitHub or on the forum. You can also connect with me on Twitter or LinkedIn.