Atsushi Suzuki

How to Load model.tar.gz from S3 for Fine-Tuning with SageMaker Training Jobs

Having previously conducted pre-training for a Tabular BERT model with SageMaker Training Jobs, I proceeded to create a job for fine-tuning using the model.tar.gz file from the pre-training results.

Although this job is fundamentally similar to the pre-training one, a few adjustments were needed in the following areas, which I have compiled here as a memo:

  • Extracting the tar file
  • Switching arguments between local environment and SageMaker using environment variables

Extracting the tar file

The pre-training job saves its output as a model.tar.gz file on S3.

This archive contains the following files: the model pytorch_model.bin, the configuration file config.json, the dictionary file vocab.nb, and the token-to-id conversion file vocab_token2id.bin. To load these files during fine-tuning, the archive must be extracted when the job runs.
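Before wiring the archive into a job, you can sanity-check its layout locally. The sketch below builds a stand-in model.tar.gz with the same file names the pre-training job produces (the file contents here are dummies) and lists its members without extracting anything:

```python
import tarfile
import tempfile
from pathlib import Path

# Build a stand-in model.tar.gz mimicking the pre-training output
# (the checkpoint-500/ subdirectory matches the paths used later).
workdir = Path(tempfile.mkdtemp())
for name in ["vocab.nb", "vocab_token2id.bin",
             "checkpoint-500/pytorch_model.bin", "checkpoint-500/config.json"]:
    target = workdir / name
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(b"dummy")

archive = workdir / "model.tar.gz"
with tarfile.open(archive, "w:gz") as tar:
    for name in ["vocab.nb", "vocab_token2id.bin", "checkpoint-500"]:
        tar.add(workdir / name, arcname=name)

# Inspect the members without extracting anything.
with tarfile.open(archive, "r:gz") as tar:
    members = sorted(m.name for m in tar.getmembers())
print(members)
```

Running this against the real archive downloaded from S3 (skip the build step) is a quick way to confirm the paths your fine-tuning script expects.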

First, set the S3 path of the model.tar.gz file as the input_model channel in the job definition. SageMaker then places model.tar.gz in the /opt/ml/input/data/input_model/ directory (passed to the script as model_path) when the job is executed.

import sagemaker
from sagemaker.estimator import Estimator

session = sagemaker.Session()
role = sagemaker.get_execution_role()

estimator = Estimator(
    image_uri=<image-url>,
    role=role,
    instance_type="ml.g4dn.2xlarge",
    instance_count=1,
    base_job_name="tabformer-opt-fine-tuning",
    output_path="s3://<bucket-name>/sagemaker/output_data/fine_tuning",
    code_location="s3://<bucket-name>/sagemaker/output_data/fine_tuning",
    sagemaker_session=session,
    entry_point="fine-tuning.sh",
    dependencies=["tabformer-opt"],
    hyperparameters={
        "data_root": "/opt/ml/input/data/input_data/",
        "data_fname": "summary",
        "output_dir": "/opt/ml/model/",
        "model_path": "/opt/ml/input/data/input_model/",
    }
)
estimator.fit({
    "input_data": "s3://<bucket-name>/sagemaker/input_data/summary.csv",
    "input_model": "s3://<bucket-name>/sagemaker/output_data/pre_training/tabformer-opt-2022-12-16-07-00-45-931/output/model.tar.gz"
})
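Each key passed to estimator.fit() becomes an input channel, and SageMaker mounts channel X under /opt/ml/input/data/X/ inside the container. A tiny helper (channel_dir is my own name, not part of the SageMaker SDK) makes the convention explicit:

```python
# SageMaker mounts each fit() channel under /opt/ml/input/data/<channel>/.
# channel_dir is a hypothetical helper, not a SageMaker SDK function.
def channel_dir(channel_name: str) -> str:
    return f"/opt/ml/input/data/{channel_name}/"

# The two channels used in the estimator above:
print(channel_dir("input_data"))   # → /opt/ml/input/data/input_data/
print(channel_dir("input_model"))  # → /opt/ml/input/data/input_model/
```

This is why the model_path hyperparameter above is set to /opt/ml/input/data/input_model/: it simply mirrors the channel name.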

Next, include the following in the fine-tuning execution file tabformer_bert_fine_tuning.py:

import tarfile
from os import path

with tarfile.open(name=path.join(args.model_path, 'model.tar.gz'), mode="r:gz") as mytar:
    mytar.extractall(path.join(args.model_path, 'model'))

    token2id_file = path.join(args.model_path, "model/vocab_token2id.bin")
    vocab_file = path.join(args.model_path, "model/vocab.nb")
    pretrained_model = path.join(args.model_path, "model/checkpoint-500/pytorch_model.bin")
    pretrained_config = path.join(args.model_path, "model/checkpoint-500/config.json")

The tarfile.open() function reads the model.tar.gz file, and mytar.extractall(path.join(args.model_path, 'model')) extracts its contents under the /opt/ml/input/data/input_model/model/ directory.

This allows you to load the extracted files, such as with token2id_file = path.join(args.model_path, "model/vocab_token2id.bin").
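The extraction step can be wrapped in a small self-contained function. This is a hedged sketch under my own naming (extract_model_archive is not from the article's repository); on Python 3.12+ you could additionally pass filter="data" to extractall() to reject unsafe archive members. The demo builds a throwaway archive so the function can be exercised without S3:

```python
import tarfile
import tempfile
from pathlib import Path

def extract_model_archive(model_path: str) -> Path:
    """Extract <model_path>/model.tar.gz into <model_path>/model/.

    Hypothetical helper mirroring the step above; returns the
    directory containing the extracted files.
    """
    model_dir = Path(model_path) / "model"
    with tarfile.open(Path(model_path) / "model.tar.gz", mode="r:gz") as tar:
        tar.extractall(model_dir)
    return model_dir

# Demo: build a throwaway archive that mimics the pre-training output.
tmp = Path(tempfile.mkdtemp())
src = tmp / "vocab.nb"
src.write_bytes(b"dummy vocab")
with tarfile.open(tmp / "model.tar.gz", "w:gz") as tar:
    tar.add(src, arcname="vocab.nb")

extracted = extract_model_archive(str(tmp))
print((extracted / "vocab.nb").exists())  # → True
```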

Switching arguments between local environment and SageMaker using environment variables

With this setup, you can now load the model.tar.gz file from S3. However, you may also want to load the files from a different location when running fine-tuning locally.

To handle this, you can call os.getenv('SM_MODEL_DIR') to read the SageMaker environment variable SM_MODEL_DIR (the directory whose contents are uploaded to S3 when the container terminates). Since this variable is only set inside SageMaker training containers, its presence lets you switch the file source between the local and SageMaker (job) environments.

import os
import tarfile
from os import path

key = os.getenv('SM_MODEL_DIR')

if key:
    with tarfile.open(name=path.join(args.model_path, 'model.tar.gz'), mode="r:gz") as mytar:
        mytar.extractall(path.join(args.model_path, 'model'))

        token2id_file = path.join(args.model_path, "model/vocab_token2id.bin")
        vocab_file = path.join(args.model_path, "model/vocab.nb")
        pretrained_model = path.join(args.model_path, "model/checkpoint-500/pytorch_model.bin")
        pretrained_config = path.join(args.model_path, "model/checkpoint-500/config.json")
else:
    vocab_file = path.join(args.model_path, "vocab.nb")
    token2id_file = path.join(args.model_path, "vocab_token2id.bin")
    pretrained_model = path.join(args.model_path, "checkpoint-500/pytorch_model.bin")
    pretrained_config = path.join(args.model_path, "checkpoint-500/config.json")
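The branch can also be condensed into a single path-resolving helper, which is easier to unit-test in both environments. This is a sketch under my own naming (resolve_artifact_paths is hypothetical, not from the article's code); it only resolves paths and assumes the archive has already been extracted in the SageMaker case:

```python
import os
from os import path

# Hypothetical helper mirroring the branch above: when SM_MODEL_DIR is set
# (i.e. inside a SageMaker container), artifacts live under the extracted
# model/ subdirectory; locally they sit directly under model_path.
def resolve_artifact_paths(model_path: str) -> dict:
    prefix = "model/" if os.getenv("SM_MODEL_DIR") else ""
    return {
        "token2id_file": path.join(model_path, prefix + "vocab_token2id.bin"),
        "vocab_file": path.join(model_path, prefix + "vocab.nb"),
        "pretrained_model": path.join(model_path, prefix + "checkpoint-500/pytorch_model.bin"),
        "pretrained_config": path.join(model_path, prefix + "checkpoint-500/config.json"),
    }

# Inside a training job, SageMaker sets SM_MODEL_DIR for us; simulate that:
os.environ["SM_MODEL_DIR"] = "/opt/ml/model"
print(resolve_artifact_paths("/opt/ml/input/data/input_model/")["vocab_file"])
# → /opt/ml/input/data/input_model/model/vocab.nb

# Locally, the variable is absent:
del os.environ["SM_MODEL_DIR"]
print(resolve_artifact_paths("./artifacts")["vocab_file"])
# → ./artifacts/vocab.nb
```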
