Generative AI for Document Understanding with Hugging Face and Amazon SageMaker

May 23, 2023 | 14 minute read

In this tutorial, you will learn how to fine-tune and deploy Donut-base for document understanding and document parsing using Hugging Face Transformers and Amazon SageMaker. Donut is a new document-understanding model that achieves state-of-the-art performance and is released under the MIT license, which allows commercial use, unlike other models such as LayoutLMv2/LayoutLMv3.

You will learn how to:

  1. Setup Development Environment
  2. Load SROIE dataset
  3. Preprocess and upload dataset for Donut
  4. Fine-tune Donut model on Amazon SageMaker
  5. Deploy Donut model on Amazon SageMaker

Quick intro: Document Understanding Transformer (Donut) by ClovaAI

Document Understanding Transformer (Donut) is a new Transformer model for OCR-free document understanding. It doesn't require an OCR engine to process scanned documents, yet it achieves state-of-the-art performance on various visual document understanding tasks, such as visual document classification or information extraction (a.k.a. document parsing). Donut is a multimodal sequence-to-sequence model with a vision encoder (Swin Transformer) and a text decoder (BART). The encoder receives an image and computes an embedding, which is passed to the decoder, which then generates a sequence of tokens.


Now we know how Donut works, so let's get started. 🚀

1. Setup Development Environment

The first step is to install the required libraries and set up the environment. We will use the Amazon SageMaker Python SDK to interact with SageMaker. We will also use Hugging Face Transformers and Datasets to preprocess the data.

!pip install "transformers[sentencepiece]==4.26.0" "datasets[s3]==2.9.0" sagemaker --upgrade --quiet

If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find out more about it here.

import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()
try:
    role = sagemaker.get_execution_role()
except ValueError:
    # outside of SageMaker, look up the execution role by name
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

2. Load SROIE dataset

We will use the SROIE dataset, a collection of 1000 scanned receipts including their OCR. More specifically, we will use the dataset from task 2 "Scanned Receipt OCR". The dataset available on Hugging Face (darentang/sroie) is not compatible with Donut. That's why we will use the original dataset together with the imagefolder feature of datasets to load our dataset. Learn more about loading image data here.

Note: The test data for task 2 is sadly not available, which means that we end up with only 626 images.

First, we will clone the repository, extract the dataset into a separate folder and remove the unnecessary files.

# clone repository
git clone
# copy data
cp -r ICDAR-2019-SROIE/data ./
# clean up
rm -rf ICDAR-2019-SROIE
rm -rf data/box

Now we have two folders inside the data/ directory. One contains the images of the receipts and the other contains the OCR text. The next step is to create a metadata.jsonl file that contains the information about the images, including the OCR text. This is necessary for the imagefolder feature of datasets.

In the end, the metadata.jsonl should look similar to the example below.

{"file_name": "0001.png", "text": "This is a golden retriever playing with a ball"}
{"file_name": "0002.png", "text": "A german shepherd"}

In our example, the "text" column will contain the OCR text of the image, which will later be used for creating the Donut-specific format.

import json
from pathlib import Path
import shutil
# define paths
base_path = Path("data")
metadata_path = base_path.joinpath("key")
image_path = base_path.joinpath("img")
# define metadata list
metadata_list = []
# parse metadata
for file_name in metadata_path.glob("*.json"):
  with open(file_name, "r") as json_file:
    # load json file
    data = json.load(json_file)
    # create "text" column with json string
    text = json.dumps(data)
    # add to metadata list if image exists
    if image_path.joinpath(f"{file_name.stem}.jpg").is_file():
      metadata_list.append({"text": text, "file_name": f"{file_name.stem}.jpg"})
# write jsonline file
with open(image_path.joinpath('metadata.jsonl'), 'w') as outfile:
    for entry in metadata_list:
        json.dump(entry, outfile)
        outfile.write('\n')  # one JSON object per line
# remove old meta data
shutil.rmtree(metadata_path)
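
Before loading the folder, it can help to sanity-check the generated file. The following is a minimal sketch, not part of the original tutorial: validate_metadata is a hypothetical helper that assumes the layout described above, with metadata.jsonl stored next to the images and a "text" column that holds a JSON string.

```python
import json
from pathlib import Path


def validate_metadata(image_dir):
    """Return a list of problems found in image_dir/metadata.jsonl (hypothetical helper)."""
    image_dir = Path(image_dir)
    problems = []
    with open(image_dir / "metadata.jsonl", "r") as infile:
        for line_number, line in enumerate(infile, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                entry = json.loads(line)
            except json.JSONDecodeError as error:
                problems.append(f"line {line_number}: invalid JSON ({error})")
                continue
            if not isinstance(entry, dict):
                problems.append(f"line {line_number}: expected a JSON object")
                continue
            # every entry needs a file_name that points to an existing image
            file_name = entry.get("file_name")
            if not file_name or not (image_dir / file_name).is_file():
                problems.append(f"line {line_number}: missing image {file_name!r}")
            # the "text" column should itself be a JSON string
            try:
                json.loads(entry.get("text"))
            except (json.JSONDecodeError, TypeError):
                problems.append(f"line {line_number}: 'text' is not valid JSON")
    return problems
```

Calling validate_metadata("data/img") should return an empty list if every line parses and every referenced image exists.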

Good Job! Now we can load the dataset using the imagefolder feature of datasets.

from datasets import load_dataset
dataset = load_dataset("imagefolder", data_dir=image_path, split="train")
print(f"Dataset has {len(dataset)} images")
print(f"Dataset features are: {dataset.features.keys()}")

Now, let's take a closer look at our dataset.

import random
random_sample = random.randint(0, len(dataset) - 1)  # randint includes both endpoints
print(f"Random sample is {random_sample}")
print(f"OCR text is {dataset[random_sample]['text']}")


3. Preprocess and upload dataset for Donut

As we learned in the introduction, Donut is a sequence-to-sequence model with a vision encoder and text decoder. When fine-tuning the model we want it to generate the "text" based on the image we pass it. Similar to NLP tasks, we have to tokenize and preprocess the text. Before we can tokenize the text, we need to transform the JSON string into a Donut compatible document.

current JSON string

{
  "company": "ADVANCO COMPANY",
  "date": "17/01/2018",
  "address": "NO 1&3, JALAN WANGSA DELIMA 12, WANGSA LINK, WANGSA MAJU, 53300 KUALA LUMPUR",
  "total": "7.00"
}

Donut document

<s><s_company>ADVANCO COMPANY</s_company><s_date>17/01/2018</s_date><s_address>NO 1&3, JALAN WANGSA DELIMA 12, WANGSA LINK, WANGSA MAJU, 53300 KUALA LUMPUR</s_address><s_total>7.00</s_total></s>

To easily create those documents, the ClovaAI team has created a json2token method, which we extract and then apply.

new_special_tokens = [] # new tokens which will be added to the tokenizer
task_start_token = "<s>"  # start of task token
eos_token = "</s>" # eos token of tokenizer
def json2token(obj, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
    """
    Convert an ordered JSON object into a token sequence
    """
    if type(obj) == dict:
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        else:
            output = ""
            if sort_json_key:
                keys = sorted(obj.keys(), reverse=True)
            else:
                keys = obj.keys()
            for k in keys:
                if update_special_tokens_for_json_key:
                    new_special_tokens.append(fr"<s_{k}>") if fr"<s_{k}>" not in new_special_tokens else None
                    new_special_tokens.append(fr"</s_{k}>") if fr"</s_{k}>" not in new_special_tokens else None
                # wrap the (recursively converted) value in its opening and closing key tokens
                output += (
                    fr"<s_{k}>"
                    + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + fr"</s_{k}>"
                )
            return output
    elif type(obj) == list:
        return r"<sep/>".join(
            [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
        )
    else:
        # excluded special tokens for now
        obj = str(obj)
        if f"<{obj}/>" in new_special_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj
def preprocess_documents_for_donut(sample):
    # create Donut-style input
    text = json.loads(sample["text"])
    d_doc = task_start_token + json2token(text) + eos_token
    # convert all images to RGB
    image = sample["image"].convert('RGB')
    return {"image": image, "text": d_doc}
proc_dataset = dataset.map(preprocess_documents_for_donut)
print(f"Sample: {proc_dataset[45]['text']}")
print(f"New special tokens: {new_special_tokens + [task_start_token] + [eos_token]}")
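
To make the target format more concrete, here is a simplified inverse of the conversion for flat documents only. It is a sketch for illustration: flat_tokens_to_dict is a hypothetical helper, and it does not handle nested keys or <sep/> lists the way the processor's own token2json method (used later in this tutorial) does.

```python
import re


def flat_tokens_to_dict(document):
    """Parse a flat Donut document such as <s><s_total>7.00</s_total></s> into a dict."""
    # each field looks like <s_key>value</s_key>; the backreference \1 forces
    # the closing token to carry the same key as the opening token
    fields = re.findall(r"<s_(\w+)>(.*?)</s_\1>", document)
    return {key: value for key, value in fields}


document = (
    "<s><s_total>7.00</s_total><s_date>17/01/2018</s_date>"
    "<s_company>ADVANCO COMPANY</s_company></s>"
)
print(flat_tokens_to_dict(document))
# {'total': '7.00', 'date': '17/01/2018', 'company': 'ADVANCO COMPANY'}
```

The task start token <s> and the end token </s> are ignored by the pattern because they contain no underscore-prefixed key.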

The next step is to tokenize our text and encode the images into tensors. To do that, we load the DonutProcessor, add our new special tokens, and reduce the image size used during processing from [1920, 2560] to [720, 960] so that training needs less memory and runs faster.

from transformers import DonutProcessor
# Load processor
model_id = "naver-clova-ix/donut-base"
processor = DonutProcessor.from_pretrained(model_id)
# add new special tokens to tokenizer
processor.tokenizer.add_special_tokens({"additional_special_tokens": new_special_tokens + [task_start_token] + [eos_token]})
# we update some settings which differ from pretraining; namely the size of the images + no rotation required
# resizing the image to smaller sizes from [1920, 2560] to [720, 960]
processor.feature_extractor.size = [720,960] # should be (width, height)
processor.feature_extractor.do_align_long_axis = False
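
As a rough back-of-the-envelope check of what the smaller image size buys us, we can compare pixel counts. This is only an approximation under the assumption that memory use grows with the number of pixels; the actual savings depend on the model and the training setup.

```python
# image sizes as used in this tutorial
pretraining_size = (1920, 2560)
fine_tuning_size = (720, 960)

pretraining_pixels = pretraining_size[0] * pretraining_size[1]  # 4,915,200
fine_tuning_pixels = fine_tuning_size[0] * fine_tuning_size[1]  # 691,200

reduction = pretraining_pixels / fine_tuning_pixels
print(f"{reduction:.2f}x fewer pixels per image")
# 7.11x fewer pixels per image

# both sizes share the same 3:4 aspect ratio, so receipts are not distorted
same_aspect_ratio = (
    pretraining_size[0] * fine_tuning_size[1] == pretraining_size[1] * fine_tuning_size[0]
)
print(same_aspect_ratio)
# True
```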

Now, we can prepare our dataset, which we will use for the training later.

def transform_and_tokenize(sample, processor=processor, split="train", max_length=512, ignore_id=-100):
    # create tensor from image
    try:
        pixel_values = processor(
            sample["image"], random_padding=split == "train", return_tensors="pt"
        ).pixel_values.squeeze()
    except Exception as e:
        print(f"Error: {e}")
        return {}
    # tokenize document
    input_ids = processor.tokenizer(
        sample["text"],
        add_special_tokens=False,  # the text already contains the start and end tokens
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)
    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = ignore_id  # model doesn't need to predict pad token
    return {"pixel_values": pixel_values, "labels": labels, "target_sequence": sample["text"]}
# need at least 32-64GB of RAM to run this
processed_dataset = proc_dataset.map(transform_and_tokenize, remove_columns=["image","text"])
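
The label masking step can be illustrated without tensors. The sketch below uses plain Python lists and made-up token ids; mask_padding is a hypothetical helper that mirrors the line which replaces the pad token id with ignore_id. The value -100 matches the default ignore_index of the PyTorch cross-entropy loss, so masked positions do not contribute to the loss.

```python
def mask_padding(input_ids, pad_token_id, ignore_id=-100):
    """Return a copy of input_ids with padding positions replaced by ignore_id."""
    return [ignore_id if token_id == pad_token_id else token_id for token_id in input_ids]


# hypothetical token ids; 1 plays the role of the pad token here
input_ids = [57525, 57526, 3021, 57527, 2, 1, 1, 1]
labels = mask_padding(input_ids, pad_token_id=1)
print(labels)
# [57525, 57526, 3021, 57527, 2, -100, -100, -100]
```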

Before we can upload our dataset to S3 for training we want to split the dataset into train and test sets.

processed_dataset = processed_dataset.train_test_split(test_size=0.1)

After that is done, we use the new FileSystem integration to upload our dataset to S3. We are using sess.default_bucket(); adjust this if you want to store the dataset in a different S3 bucket. We will use the S3 path later in our training script.

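# save train_dataset to s3
training_input_path = f's3://{sess.default_bucket()}/processed/donut-sagemaker/train'
processed_dataset["train"].save_to_disk(training_input_path)
# save test_dataset to s3
test_input_path = f's3://{sess.default_bucket()}/processed/donut-sagemaker/test'
processed_dataset["test"].save_to_disk(test_input_path)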
print("uploaded data to:")
print(f"training dataset to: {training_input_path}")
print(f"test dataset to: {test_input_path}")

4. Fine-tune Donut model on Amazon SageMaker

After we have processed our dataset, we can start training our model with an Amazon SageMaker training job using the HuggingFace Estimator. The Estimator handles end-to-end Amazon SageMaker training and deployment tasks and manages the infrastructure use. SageMaker takes care of starting and managing all the required EC2 instances for us, provides the correct Hugging Face container, uploads the provided scripts and downloads the data from our S3 bucket into the container at /opt/ml/input/data. Then, it starts the training job by running the training script.

One important point to keep in mind is that we extended the DonutProcessor earlier and added special tokens, which we need to pass through to our training script. We also need to pass the image_size and max_length to our training script.

As the pretrained model we will use naver-clova-ix/donut-base. The donut-base checkpoint includes only the pre-trained weights and was introduced in the paper OCR-free Document Understanding Transformer by Geewook Kim et al. and first released in this repository.

In addition to loading our model, we are resizing the embedding layer to match newly added tokens and adjusting the image_size of our encoder to match our dataset. We are also adding tokens for inference later.
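
The training script itself is not shown in this tutorial. As a hedged sketch of how such a script could receive its settings: in SageMaker script mode, hyperparameters arrive as command-line arguments, and the location of the "training" channel is exposed through the SM_CHANNEL_TRAINING environment variable. The argument names below mirror the hyperparameters used in this tutorial, while parse_args and the defaults are assumptions for illustration.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters passed to a (hypothetical) training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="naver-clova-ix/donut-base")
    parser.add_argument("--special_tokens", type=str, default=None)
    parser.add_argument(
        "--dataset_path",
        type=str,
        # fall back to the documented channel directory if the variable is not set
        default=os.environ.get("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training"),
    )
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=4e-5)
    # booleans arrive as strings such as "True", so convert them explicitly
    parser.add_argument(
        "--gradient_checkpointing",
        type=lambda value: value.lower() == "true",
        default=False,
    )
    return parser.parse_args(argv)


args = parse_args(["--epochs", "3", "--lr", "4e-5", "--gradient_checkpointing", "True"])
print(args.epochs, args.lr, args.gradient_checkpointing)
# 3 4e-05 True
```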

import time
import json
from sagemaker.huggingface import HuggingFace
# define Training Job Name
job_name = f'huggingface-donut-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'
# stringify special tokens
special_tokens = ",".join(processor.tokenizer.special_tokens_map_extended["additional_special_tokens"])
# hyperparameters, which are passed into the training job
hyperparameters = {
  'model_id': model_id,                                # pre-trained model
  'special_tokens': json.dumps(special_tokens),        # special tokens which will be added to the tokenizer
  'dataset_path': '/opt/ml/input/data/training',       # path where sagemaker will save training dataset
  'epochs': 3,                                         # number of training epochs
  'per_device_train_batch_size': 8,                    # batch size for training
  'gradient_checkpointing': True,                      # trade extra compute for lower memory usage
  'lr': 4e-5,                                          # learning rate used during training
}
# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point          = '',        # train script
    source_dir           = 'scripts',         # directory which includes all the files needed for training
    instance_type        = 'ml.g5.2xlarge',   # instances type used for the training job
    instance_count       = 1,                 # the number of instances used for training
    base_job_name        = job_name,          # the name of the training job
    role                 = role,              # IAM role used in training job to access AWS resources, e.g. S3
    volume_size          = 100,               # the size of the EBS volume in GB
    transformers_version = '4.26',            # the transformers version used in the training job
    pytorch_version      = '1.13',            # the pytorch version used in the training job
    py_version           = 'py39',            # the python version used in the training job
    hyperparameters      = hyperparameters    # hyperparameters passed to the training script
)
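
The special tokens are packed into a single string before they are handed to the training job. The following sketch shows how that encoding can be reversed, assuming the receiving side gets the exact string produced by json.dumps; the token list here is a made-up subset for illustration. Note that the comma-joined format only works as long as no token contains a comma.

```python
import json

# hypothetical subset of the special tokens collected during preprocessing
tokens = ["<s_total>", "</s_total>", "<s_date>", "</s_date>", "<s>", "</s>"]

# encode: the same two steps used for the 'special_tokens' hyperparameter
encoded = json.dumps(",".join(tokens))

# decode: undo the JSON quoting, then split on the separator
decoded = json.loads(encoded).split(",")
print(decoded == tokens)
# True
```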

Let's start the training job and wait until it is finished. This will take around 30 minutes.

# define a data input dictionary with our uploaded s3 uris
data = {'training': training_input_path}
# starting the train job with our uploaded datasets as input
huggingface_estimator.fit(data, wait=True)

5. Deploy Donut model on Amazon SageMaker

During training, we copied an inference script into our model.tar.gz, which now allows us to easily deploy our model to SageMaker for inference. The script implements a custom model_fn and predict_fn for our Donut model: model_fn loads the model and processor, and predict_fn tokenizes the input and returns the prediction.

from sagemaker.huggingface import HuggingFaceModel
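# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
   model_data=huggingface_estimator.model_data,  # S3 path to the model.tar.gz created by the training job
   role=role,                                    # IAM role with permissions to create an endpoint
   transformers_version="4.26",                  # same versions as used for the training job
   pytorch_version="1.13",
   py_version="py39",
)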

Before we can deploy the model with the HuggingFaceModel class, we need to create a new serializer that supports our image data. Serializers are used in the Predictor and in the predict method to serialize our data to a specific MIME type. The default serializer for the HuggingFacePredictor is a JSON serializer, but since we are not going to send text data to the endpoint, we will use the DataSerializer.

from sagemaker.serializers import DataSerializer
# create a serializer for the data
image_serializer = DataSerializer(content_type='image/x-image') # using x-image to support multiple image formats
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
   initial_instance_count=1,  # number of instances behind the endpoint
   instance_type= "ml.g5.2xlarge",
   serializer=image_serializer, # serializer for our image files.
)

SageMaker starts the deployment process by creating a SageMaker Endpoint Configuration and a SageMaker Endpoint. The Endpoint Configuration defines the model and the instance type.

Let's test it using an example from the test split.

from PIL import Image
import io
from random import randrange
from transformers.image_transforms import to_pil_image
import numpy as np
test_sample = processed_dataset["test"][randrange(0,len(processed_dataset["test"]))]
image = to_pil_image(np.array(test_sample["pixel_values"]))
def image_to_byte_array(image: Image.Image) -> bytes:
  format = image.format if image.format else 'JPEG'
  # BytesIO is a file-like buffer stored in memory
  img_byte_arr = io.BytesIO()
  # image.save expects a file-like as an argument
  image.save(img_byte_arr, format=format)
  # Turn the BytesIO object back into a bytes object
  return img_byte_arr.getvalue()
res = predictor.predict(data=image_to_byte_array(image))
target = processor.token2json(test_sample["target_sequence"])
print(f"Reference:\n {target}")
print(f"Prediction:\n {res}")
#    Reference:
#     {'total': '41.87', 'date': '24/10/2017', 'company': 'GARDENIA BAKERIES (KL) SDN BHD', 'address': 'LOT 3, JALAN PELABUR 23/1, 40300 SHAH ALAM, SELANGOR.'}
#    Prediction:
#     {'total': '41.87', 'date': '24/10/2017', 'company': 'GARDENIA BAKERIES (KL) SDN BHD', 'address': 'LOT 3, JALAN PELABUR 23/1, 40300 SHAH ALAM, SELANGOR.'}


Awesome!! Our fine-tuned model parsed the document correctly and extracted the right values. The next step is to evaluate our model on the test set. Since the model itself is a seq2seq model, it is not that straightforward to evaluate.

To keep things simple, we will use ROUGE, short for Recall-Oriented Understudy for Gisting Evaluation. This metric does not behave like standard accuracy: it compares a generated text against a set of reference texts. The ROUGE score is mostly used for summarization or machine translation tasks.

The higher the score the closer the generated text is to the reference text.

!pip install rouge-score py7zr
import evaluate
from tqdm import tqdm
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disable parallelism in tokenizers library
# Metric
rogue = evaluate.load("rouge")
predictions, references = [], []
# iterate over dataset
for sample in tqdm(processed_dataset["test"],total=len(processed_dataset["test"])):
  image = to_pil_image(np.array(sample["pixel_values"]))
  prediction = predictor.predict(data=image_to_byte_array(image))
  reference = processor.token2json(sample["target_sequence"])
  # ROUGE compares strings, so we stringify the parsed documents
  predictions.append(str(prediction))
  references.append(str(reference))
# compute scores
results = rogue.compute(predictions=predictions,references=references)
print(results)
# {'rouge1': 0.8173891303548311, 'rouge2': 0.7266157117328251, 'rougeL': 0.8167726537736875, 'rougeLsum': 0.8144718562842747}

Our model achieves a ROUGE-1 score of 81.7% on the test set. ROUGE-1 refers to the overlap of unigrams (each word) between the prediction and the reference.
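
To build some intuition for what unigram overlap means, here is a simplified sketch that computes an F1 score over whitespace-separated, lowercased tokens. It is not the implementation used by the rouge-score package, which applies its own tokenization, so the numbers will not match the results above; unigram_f1 is a hypothetical helper for illustration only.

```python
from collections import Counter


def unigram_f1(prediction, reference):
    """F1 score over the multiset overlap of whitespace-separated tokens."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter & Counter keeps the minimum count of each shared token
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


score = unigram_f1("GARDENIA BAKERIES SDN BHD", "GARDENIA BAKERIES (KL) SDN BHD")
print(round(score, 3))
# 0.889
```

In this example all four predicted words appear in the reference (precision 1.0), but the reference has five words (recall 0.8), which gives an F1 of roughly 0.889.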

Note: The evaluation we did was very simple.

In an inference test, the model predicted the address 'NO. 31G&33G, JALAN SETIA INDAH X ,U13/X 40170 SETIA ALAM', while the ground truth was 'NO. 31G&33G, JALAN SETIA INDAH X,U13/X 40170 SETIA ALAM'. The only difference is the whitespace between X and ,U13/X.

Clean up

To avoid unnecessary costs, we should delete the endpoint and the model.
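
A minimal sketch of the clean-up step, assuming predictor is the object returned by the deploy call above. The cleanup wrapper is a hypothetical convenience function; delete_model and delete_endpoint are methods of the SageMaker predictor.

```python
def cleanup(predictor):
    """Delete the model and the endpoint behind a SageMaker predictor."""
    # remove the SageMaker model that backs the endpoint
    predictor.delete_model()
    # remove the endpoint itself so that it no longer incurs costs
    predictor.delete_endpoint()
```

After testing, calling cleanup(predictor) removes both resources.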


Thanks for reading! If you have any questions, feel free to contact me on Twitter or LinkedIn.