K-Fold as Cross-Validation with a BERT Text-Classification Example
K-fold is a cross-validation method used to estimate the skill of a machine learning model on unseen data. It is commonly used to validate a model because it is easy to understand, easy to implement, and its results are more informative than those of a single train/validation split.
Cross-validation is a resampling procedure used to validate machine learning models on a limited data set. The procedure has a single parameter called K that refers to the number of groups a given data sample is to be split into, which is why the method is called K-fold.
The choice of K is usually 5 or 10, but there is no formal rule. As K gets larger, the resampling subsets get smaller. K also defines how often your machine learning model is trained. Most of the time we split our data into train/validation sets at 80%/20%, 90%/10%, or 70%/30% and train our model once. In cross-validation, we split our data K times and train one model per split. Be aware that this results in longer training processes.
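For comparison, this is what a single hold-out split looks like; a minimal sketch using scikit-learn's train_test_split, where the placeholder data and the 80%/20% ratio are illustrative assumptions:

```python
# A single 80/20 hold-out split trains the model once; K-fold (below)
# repeats the splitting and training K times. X and y are placeholder data.
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(10, 2)  # 10 samples, 2 features (illustrative)
y = np.array([0, 1] * 5)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
print(len(X_train), "training samples,", len(X_val), "validation samples")
```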
K-Fold steps:
- Shuffle the dataset.
- Split the dataset into K groups.
- For each unique group g:
  - Take g as the test set.
  - Take the remaining K-1 groups as the training set.
  - Fit a model on the training set and evaluate it on the test set.
  - Retain the evaluation score and discard the model.
- Summarize the skill of the model using the sample of K evaluation scores.
The results of a K-fold cross-validation run are often summarized with the mean of the model scores.
Scikit-Learn Example
The example below is a simple implementation with scikit-learn and a small numpy array.
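A minimal sketch of what such an implementation could look like; the array values and the choice of n_splits=3 are illustrative assumptions:

```python
# K-fold cross-validation splits on a toy one-dimensional numpy array.
import numpy as np
from sklearn.model_selection import KFold

data = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

# Shuffle the data and split it into K=3 groups.
kf = KFold(n_splits=3, shuffle=True, random_state=42)

for fold, (train_index, test_index) in enumerate(kf.split(data)):
    train, test = data[train_index], data[test_index]
    print(f"Fold {fold}: train={train}, test={test}")
```

Each iteration of the loop yields one fold as the test set and the remaining folds as the training set, exactly as in the step list above.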
Simpletransformers Example (BERT Text-Classification)
The example below is an implementation of BERT text classification with the simpletransformers library and scikit-learn.
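A minimal sketch of how the pieces could fit together, assuming a pandas DataFrame with "text" and "labels" columns and a binary classification task; the toy data, model type, hyperparameters, and accuracy metric are illustrative choices rather than the article's exact configuration:

```python
# Sketch: K-fold cross-validation for BERT text classification with
# simpletransformers and scikit-learn. The DataFrame below is a stand-in
# for a real dataset with "text" and "labels" columns.
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from simpletransformers.classification import ClassificationModel

# Hypothetical toy data; replace with your own DataFrame.
df = pd.DataFrame({
    "text": ["great movie", "terrible plot", "loved it",
             "boring", "fantastic", "awful"],
    "labels": [1, 0, 1, 0, 1, 0],
})

kf = KFold(n_splits=5, shuffle=True, random_state=42)
scores = []

for fold, (train_index, val_index) in enumerate(kf.split(df)):
    train_df = df.iloc[train_index]
    val_df = df.iloc[val_index]

    # A fresh model is trained for every fold and discarded afterwards.
    model = ClassificationModel(
        "bert",
        "bert-base-uncased",
        num_labels=2,
        args={"num_train_epochs": 1, "overwrite_output_dir": True},
        use_cuda=False,  # set to True if a GPU is available
    )
    model.train_model(train_df)

    # eval_model accepts additional metric functions as keyword arguments.
    result, model_outputs, wrong_predictions = model.eval_model(
        val_df, acc=accuracy_score
    )
    scores.append(result["acc"])
    print(f"Fold {fold}: accuracy={result['acc']:.3f}")

# Summarize the skill of the model with the mean of the fold scores.
print(f"Mean accuracy over {kf.get_n_splits()} folds: {np.mean(scores):.3f}")
```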
Benefits of K-Fold Cross-Validation
Using all data: With K-fold cross-validation we use the complete dataset, which is helpful when the dataset is small: every sample is used for both training and validation across the K runs, instead of permanently setting aside X% of the data for a validation set.
Getting more metrics: Most of the time you get a single value per metric, but with K-fold you get K values per metric and can take a much closer look at your model's performance.
Achieving higher precision: By validating your model against multiple validation sets, we get a more reliable estimate of its performance. Imagine the following example: we have 3 speakers and 1,500 recordings (500 per speaker). With a single train/validation split, the result could vary widely depending on which recordings end up in the validation set; averaging the scores over K folds reduces the dependence on any single split.