DEV Community

Chaim Rand for AWS Community Builders

Posted on • Originally published at towardsdatascience.com

A Priority Based Scheduler for Amazon SageMaker Training Jobs

Optimizing the use of limited AI training accelerators — Part 2

Photo by Adrien Aletti on Unsplash

This post was created in collaboration with Max Rabin.

This is the second part of a series of posts on the topic of maximizing the utility of scarce AI resources. In the first post we noted the increasing limitations on the ability to scale up AI resources at will and, as a consequence, the growing trend among AI development teams of guaranteeing AI compute capacity by means such as building up an in-house AI server farm and/or reserving dedicated instances in the cloud. The scarcity of AI compute resources motivates the design of specialized scheduling solutions that minimize idle time and prioritize critical workloads. Please see our previous post, in which we proposed a detailed list of requirements for such solutions. The approach we took there was to leverage the existing priority-based scheduler that comes with Kubernetes and align our training development workflow to its use.

In this post we explore the option of maintaining our existing framework for training AI models and enhancing it with our own custom implementation of a priority-based scheduler. Importantly, the need for this type of solution is often motivated not just by the scarcity of AI resources, but also by the desire to increase control over the orchestration and prioritization of training workloads so as to reduce development costs. For example, even in a scenario of abundant capacity, you may choose to limit your use to a fixed number of training instances so as to cap your training expenditure.

For the purposes of this post, we will assume that our training framework of choice is AWS’s managed service for AI model training, Amazon SageMaker. The solution we will propose will use additional AWS services such as Amazon DynamoDB and AWS Lambda. The choice to demonstrate our solution using AWS services should not be viewed as an endorsement. There are many cloud-based service offerings available and the best one for you will depend on the particular details of your project. Similar solutions to the one that we will describe can be designed in other cloud environments and/or using alternative cloud-based services.

The Traditional Method for Starting Up SageMaker Training Jobs

Traditionally, we would start up a SageMaker training job using the Amazon SageMaker Python SDK. In the code block below we use the SageMaker SDK (version 2.208) to run a PyTorch training workload on a single instance of type p5.48xlarge.

from sagemaker.pytorch import PyTorch

# define job
estimator = PyTorch(
    role='<sagemaker role>',
    entry_point='train.py',
    instance_type='ml.p5.48xlarge',
    instance_count=1,
    framework_version='2.0.1',
    py_version='py310',
    tags=[{'Key': 'priority', 'Value': '100'}]
)

# start job
estimator.fit()

When the estimator.fit() function is called, the SageMaker library uploads our code to Amazon S3 and then transforms the request into a boto3 SageMaker client create_training_job request (see here).

This method for starting up training jobs depends on the availability of the requested resources for its success. In our scenario of scarce AI resources, it is likely to fail more often than not. Although this can be partially mitigated by retaining provisioned compute instances for successive workloads, the API does not provide the appropriate tooling for maximizing their utility. Let’s suppose that we wish to utilize precisely two p5.48xlarge instances. To simplify our discussion, let’s assume that each training workload runs on a single instance. Typically, during an AI model development cycle there will be periods when more than two training workloads are waiting to be processed. The existing API would try to start up a third p5.48xlarge instance and would most likely fail due to its limited availability. Even when there is instance availability, we may wish to limit our training to just our two designated instances to increase our control over the costs of training.

We require a new API for submitting training jobs, one that does not immediately start up a new p5.48xlarge instance but instead enters the job into a priority queue. We also need an associated job scheduler that manages the use of our two instances while prioritizing critical workloads.
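The ordering we want from such a queue can be illustrated with the standard-library heapq module. This is a minimal in-memory sketch, assuming integer priorities where a larger value means more urgent and ties broken by submission order; the InMemoryJobQueue class is hypothetical and is not the DynamoDB-backed queue built in the rest of this post.

```python
import heapq
import itertools


class InMemoryJobQueue:
    """Toy priority queue: highest priority first, FIFO among equal priorities."""

    def __init__(self):
        self._heap = []
        self._counter = itertools.count()  # monotonically increasing tie-breaker

    def push(self, job_name, priority):
        # heapq is a min-heap, so negate the priority to pop the largest first
        heapq.heappush(self._heap, (-priority, next(self._counter), job_name))

    def pop(self):
        # return the name of the most urgent, longest-waiting job
        _, _, job_name = heapq.heappop(self._heap)
        return job_name

    def __len__(self):
        return len(self._heap)


queue = InMemoryJobQueue()
for name, prio in [('job1', 1), ('job2', 2), ('job3', 1), ('job4', 3)]:
    queue.push(name, prio)

order = [queue.pop() for _ in range(len(queue))]
print(order)  # ['job4', 'job2', 'job1', 'job3']
```

Equal-priority jobs come out in submission order (job1 before job3), which is the same tie-breaking rule that get_pending_jobs applies using the entryTime field.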

Importantly, please note that as of the time of this writing, Amazon SageMaker does not support the option of training on reserved Amazon EC2 instances. And although Amazon SageMaker Savings Plans have similar properties to instance reservations, they do not guarantee instance capacity. In a previous post we addressed this limitation and proposed using SageMaker managed warm pools as an alternative method for retaining access to provisioned instances. For the remainder of the post, we will assume that we are able to obtain two instances of our choice, whether through this or some other method.

Priority-Based Scheduling for Amazon SageMaker

In this section we will describe the components of our proposed solution. We will use the AWS Serverless Application Model (SAM) specification. More specifically, we will create an AWS SAM template YAML file and gradually add the AWS resources that we need. Please see the documentation for details on how to define and deploy serverless solutions using AWS SAM.

AWS Architecture Diagram (by Author)

A Private API for Submitting Training Jobs

We start by using Amazon API Gateway to define a private REST API for submitting training job requests. We name the API training-job-queue. Later, we will add a POST method called add-job and modify our training-job creation code to use this method instead of the SageMaker client create_training_job API. The code block below contains the definition of the private API resource in SAM. In practice you will likely want to specify access limitations to the API and/or a method of authorization.

AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31

Resources:
  InternalAPI:
    Type: AWS::Serverless::Api
    Properties:
      # Auth: # Add access control to API
      EndpointConfiguration:
        Type: PRIVATE
        # VPCEndpointIds: # Specify VPC Endpoint(s)
      Name: training-job-queue
      StageName: prod

Define an AWS DynamoDB Table for Storing Training Job Requests

We will use an Amazon DynamoDB table named sagemaker-queue to store the submitted training workloads. Each entry will have the following fields:

  1. jobName: Stores the unique name of the training job.
  2. entryTime: Stores the date and time that the job was added.
  3. jobState: Stores the current state of the training job. The valid values are ‘pending’, ‘running’, and ‘preempted’.
  4. priority: Stores an integer value representing the relative priority of the job.
  5. jobDetails: Stores the details of the job request.
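The jobState field effectively encodes a small state machine. The sketch below spells out the transitions implied by the start_job and preempt_job functions and the Lambda handler described later in this post. The TRANSITIONS table and the is_valid_transition helper are hypothetical illustrations rather than part of the solution, and a target of None stands for the entry being removed from the table.

```python
from typing import Optional

# Allowed jobState transitions implied by the scheduler and Lambda handler.
# None means "entry removed from the queue table".
TRANSITIONS = {
    'pending': {'running'},          # started on a free instance
    'running': {'preempted', None},  # preempted, or Completed/Failed/Stopped
    'preempted': {'running'},        # restarted once an instance frees up
}


def is_valid_transition(current: str, new: Optional[str]) -> bool:
    """Return True if moving a job from `current` to `new` is allowed."""
    return new in TRANSITIONS.get(current, set())


print(is_valid_transition('pending', 'running'))    # True
print(is_valid_transition('pending', 'preempted'))  # False
```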

We define our DynamoDB table in our SAM template YAML file using the AWS::Serverless::SimpleTable resource.

  DynamoSMQueue:
    Type: AWS::Serverless::SimpleTable  
    Properties:  
      PrimaryKey:  
        Name: jobName  
        Type: String  
      TableName: sagemaker-queue

We define a function that creates a table entry from a given training job request. We assume that the request contains the same contents as the input to the create_training_job API, in JSON format. We further assume that the priority of the workload is entered as a key-value tag in the training job definition.

import json, boto3
from datetime import datetime

dynamodb = boto3.resource('dynamodb')
table = dynamodb.Table('sagemaker-queue')

# add a new job entry to the queue table
# (job_json is the create_training_job request as a JSON string)
def add_job_entry(job_json):
    job_details = json.loads(job_json)

    # extract job_name
    job_name = job_details['TrainingJobName']
    print(f'add entry {job_name}')

    # get current time
    entry_time = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    # default priority is 0
    priority = 0

    # update priority based on tags (the request may have no tags at all)
    tags = job_details.get('Tags', [])
    for tag in tags:
        if tag['Key'] == 'priority':
            priority = int(tag['Value'])
            break

    # create entry
    entry = {
       'jobName': job_name,
       'entryTime': entry_time,
       'jobState': 'pending',
       'priority': priority,
       'jobDetails': job_json
    }
    table.put_item(Item=entry) # TODO: handle errors
    print(f'Added job {job_name} to queue')

The REST API add-job method that we will soon define will be programmed to call the add_job_entry function.
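To make the assumed request format concrete, here is a hypothetical and heavily trimmed create_training_job request body (a real request contains many more fields), together with the priority-extraction logic of add_job_entry pulled out into a standalone helper. The helper name extract_priority is ours; its default of 0 when no priority tag is present matches add_job_entry.

```python
import json

# Hypothetical, heavily trimmed request body as it might arrive at add-job.
request_body = json.dumps({
    'TrainingJobName': 'job1',
    'Tags': [{'Key': 'team', 'Value': 'vision'},
             {'Key': 'priority', 'Value': '100'}],
})


def extract_priority(job_details, default=0):
    """Return the integer value of the 'priority' tag, or `default` if absent."""
    for tag in job_details.get('Tags', []):
        if tag['Key'] == 'priority':
            return int(tag['Value'])
    return default


details = json.loads(request_body)
print(details['TrainingJobName'], extract_priority(details))  # job1 100
print(extract_priority({'TrainingJobName': 'job2'}))          # 0
```

Note that tag values are strings in the SageMaker API, which is why the value is converted with int() before being stored as the job's priority.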

We define a second function that extracts the pending jobs from the database and returns them in order of priority. In the case that multiple jobs have the same priority, they are ordered according to the amount of time they have been waiting in the queue.

from boto3.dynamodb.conditions import Attr  

# Get a list of all jobs waiting to run ('pending' or 'preempted'), sorted by priority
def get_pending_jobs():  
    response = table.scan(  
        ProjectionExpression='jobName, priority, entryTime',  
        FilterExpression=Attr('jobState').ne('running')  
    )  
    jobs = response.get('Items', [])  

    # sort jobs, first by priority (descending) and then by entryTime  
    sorted_jobs = sorted(jobs,  
                         key=lambda x: (-x['priority'], x['entryTime']))  

    return sorted_jobs

The following utility functions will come in handy in the next sections.

# Get a jobName -> priority mapping of all running jobs  
def get_running_jobs_dict():  
    # Get all running jobs  
    response = table.scan(  
        ProjectionExpression="jobName, priority",  
        FilterExpression=Attr('jobState').eq('running')  
    )  
    jobs = response.get('Items', [])  

    running_jobs = {job['jobName']: job['priority'] for job in jobs}  

    return running_jobs  

# Print the queue state  
def print_queue_state():  
    response = table.scan(  
        ProjectionExpression='jobName, jobState, priority'  
    )  
    jobs = response.get('Items', [])  

    print_table = []  
    for job in jobs:  
        print_table.append([job['jobName'], job['jobState'], job['priority']])  

    # sort by priority  
    sorted_table = sorted(print_table,  
                         key=lambda x: -x[2])  
    # Print the table  
    from tabulate import tabulate  
    print(tabulate(sorted_table, headers=['Job Name', 'State', 'Priority']))  

# get job details or None if the job does not exist
def get_job_details(job_name):
    response = table.get_item(
        Key={'jobName': job_name},
        ProjectionExpression='jobDetails'
    )
    job = response.get('Item')
    return json.loads(job['jobDetails']) if job else None

# get job state or None if the job does not exist  
def get_job_state(job_name):  
    response = table.get_item(  
        Key={'jobName': job_name},  
        ProjectionExpression='jobState'  
    )  
    job = response.get('Item')  
    return job.get('jobState') if job else None  

# update the job state  
def update_job_state(job_name, new_state):  
    table.update_item(  
        Key={'jobName': job_name},  
        UpdateExpression="SET jobState = :new_state",  
        ExpressionAttributeValues={":new_state": new_state}  
    )  
    print(f'Update job {job_name} to {new_state}')  

# remove a job entry  
def remove_job(job_name):  
    table.delete_item(  
        Key={'jobName': job_name}  
    )  
    print(f'Removed job {job_name} from queue')

Both our choice of DynamoDB and its usage (e.g., our use of the Scan API rather than the Query API) assume that the overall number of jobs in our queue will be in the dozens, at most. For a larger-scale solution, you may be better off with a heavier-duty database (e.g., one that performs the sorting for you) or a more sophisticated use of DynamoDB (e.g., see here).
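One detail worth knowing if the queue grows: a single DynamoDB Scan call reads at most 1 MB of data, and the remainder must be fetched by passing the LastEvaluatedKey of each response as the ExclusiveStartKey of the next call. The helper below is a minimal sketch of that loop. It accepts any callable with the same keyword interface as the boto3 Table.scan method (in practice, table.scan), so the paging logic can be exercised without AWS; the name scan_all and the fake scan function are hypothetical.

```python
def scan_all(scan_fn, **scan_kwargs):
    """Collect all items from a paginated DynamoDB-style scan.

    scan_fn is assumed to behave like boto3's Table.scan: it accepts an
    optional ExclusiveStartKey and returns a dict with 'Items' and, when
    more data remains, 'LastEvaluatedKey'.
    """
    items = []
    response = scan_fn(**scan_kwargs)
    items.extend(response.get('Items', []))
    while 'LastEvaluatedKey' in response:
        response = scan_fn(ExclusiveStartKey=response['LastEvaluatedKey'],
                           **scan_kwargs)
        items.extend(response.get('Items', []))
    return items


# Fake two-page scan, standing in for table.scan in this example
pages = {
    None: {'Items': [{'jobName': 'job1'}], 'LastEvaluatedKey': {'jobName': 'job1'}},
    'job1': {'Items': [{'jobName': 'job2'}]},
}


def fake_scan(ExclusiveStartKey=None, **kwargs):
    key = ExclusiveStartKey['jobName'] if ExclusiveStartKey else None
    return pages[key]


print([item['jobName'] for item in scan_all(fake_scan)])  # ['job1', 'job2']
```

With boto3, get_pending_jobs could call scan_all(table.scan, ProjectionExpression=..., FilterExpression=...) in place of a single table.scan call. Because the filter is applied after each page is read, a page may legitimately come back empty while still carrying a LastEvaluatedKey, which the loop handles.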

Define the Training Job Queue Manager

The main component of our solution is the training job scheduler. Here we implement a rather simple manager that performs the following steps:

  1. Extract the list of queued jobs, ordered by priority. If none exist, return.
  2. Discover unused instance capacity. For each free instance, start one pending job on SageMaker. If no jobs remain after that, return.
  3. Calculate the number of SageMaker jobs in the Stopping state. If it is greater than or equal to the number of remaining pending jobs, return, since the instances they release will be used to run those jobs.
  4. Assess the need for preemption of running SageMaker jobs by comparing their priorities to those of our pending jobs.

# set the limit on total number of instances/jobs  
MAX_CAPACITY = 2  

sagemaker = boto3.client('sagemaker')  

# apply a queue stamp to identify that the job came from the queue  
def apply_qstamp(job_name):  
    return f'{job_name}-qstamp-{datetime.now().strftime("%d%H%M")}'  

# strip the queue stamp  
def strip_qstamp(job_name):  
    return job_name.split('-qstamp-')[0]  

# start a SageMaker job and update job entry in queue
def start_job(job_name):
    print(f'start job {job_name}')
    job_details = get_job_details(job_name)
    if job_details:
        # start job with details from queue, under a queue-stamped name
        # (you may optionally overwrite fields such as the iam role)
        job_details['TrainingJobName'] = apply_qstamp(job_name)
        response = sagemaker.create_training_job(**job_details)
        if response['ResponseMetadata']['HTTPStatusCode'] == 200:
            print(f'started job {job_name}')
            update_job_state(job_name, 'running')

# preempt a SageMaker job and update job entry in queue  
def preempt_job(job_name):  
    print(f'preempt job {job_name}')  
    response = sagemaker.stop_training_job(TrainingJobName=job_name)  
    if response['ResponseMetadata']['HTTPStatusCode'] == 200:  
        print(f'preempted job {job_name}')  
        update_job_state(strip_qstamp(job_name), 'preempted')  

# get SageMaker jobs  
def get_sagemaker_jobs(status):  
    running = sagemaker.list_training_jobs(StatusEquals=status)  
    return running.get('TrainingJobSummaries', [])  

# queue manager  
def manage_queue():  
    # extract pending jobs to run  
    pending = get_pending_jobs()  

    if not pending:  
        return  

    if len(pending) > MAX_CAPACITY:  
        pending = pending[:MAX_CAPACITY]  

    # get running sagemaker jobs  
    running = get_sagemaker_jobs('InProgress')  
    total_running = len(running)  

    # get stopping sagemaker jobs  
    stopping = get_sagemaker_jobs('Stopping')  
    total_stopping = len(stopping)  

    # calculate the number of free instances (never negative, e.g., if
    # jobs launched outside the queue are also running)
    free_slots = max(0, MAX_CAPACITY - total_running - total_stopping)

    jobs_to_start = min(len(pending), free_slots)  

    # for each free instance, start a job  
    for i in range(jobs_to_start):  
        start_job(pending[i].get('jobName'))  

    still_pending = pending[jobs_to_start:]  

    if not still_pending:  
        return  

    # assume that 'total_stopping' number of jobs will start soon  
    test_for_preemption = len(still_pending) - total_stopping  
    if test_for_preemption <= 0:  
        return  

    # check if preemption is required  
    test_priority = still_pending[total_stopping:]  

    running_jobs = get_running_jobs_dict()
    priority_dict = {}
    for job in running:
        job_name = job['TrainingJobName']
        queue_name = strip_qstamp(job_name)
        # skip SageMaker jobs that were not started by the queue
        if queue_name in running_jobs:
            priority_dict[job_name] = running_jobs[queue_name]

    # sort running jobs from lowest to highest priority
    sorted_running = sorted(priority_dict.items(), key=lambda item: item[1])

    # preempt the lowest-priority running jobs that are outranked
    index = 0
    while index < min(test_for_preemption, len(sorted_running)) and \
          test_priority[index].get('priority') > sorted_running[index][1]:
        preempt_job(sorted_running[index][0])
        index = index + 1

Important notes:

  1. Our implementation is highly optimistic in the sense that we assume that all the jobs that are inserted are valid and that we will be able to start them up on SageMaker without issue. In practice, appropriate error handling should be added (e.g., removing faulty jobs from the queue with appropriate logging).
  2. In a production environment, we would need to take into consideration the likely occurrence of a race condition when our manage_queue function is triggered by multiple concurrent events. There are several ways of addressing this problem (e.g., see here), including enforcing atomicity (e.g., by setting our Lambda function’s concurrency to one), using some form of locking mechanism (e.g., as done here), or making our function idempotent. Here we have taken the approach of what we call “optimistic idempotence”, where we rely on appropriate use of the API and on the idempotency of our underlying calls to the SageMaker APIs.
  3. We emphasize that our implementation is naïve. In practice, we recommend a more sophisticated algorithm that 1) accounts for the use of different types of instances and jobs that require more than one instance, 2) takes all edge cases into consideration, and 3) is tailored towards the specific needs of your project.
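Because manage_queue mixes its decision logic with AWS calls, it can be convenient to factor the decision into a pure function that can be unit-tested locally. The sketch below re-expresses steps 1 to 4 under the same simplifying assumptions (one instance per job, a fixed capacity). The function plan_actions and its argument format are hypothetical; for simplicity it takes the running jobs' priorities directly instead of looking them up in DynamoDB.

```python
def plan_actions(pending, running, num_stopping, capacity):
    """Decide which queued jobs to start and which running jobs to preempt.

    pending:      list of (job_name, priority) tuples, sorted by priority
                  (descending) and then by entry time
    running:      dict mapping each running job_name to its priority
    num_stopping: number of jobs currently in the Stopping state
    capacity:     maximum number of instances
    Returns a tuple (jobs_to_start, jobs_to_preempt).
    """
    # step 1: only the top `capacity` pending jobs are candidates
    pending = pending[:capacity]
    if not pending:
        return [], []

    # step 2: fill free instances with the highest-priority pending jobs
    free_slots = max(0, capacity - len(running) - num_stopping)
    to_start = [name for name, _ in pending[:free_slots]]
    still_pending = pending[free_slots:]

    # step 3: instances released by stopping jobs will serve the next jobs
    candidates = still_pending[num_stopping:]

    # step 4: pair the strongest candidates with the weakest running jobs and
    # preempt while the candidate strictly outranks the running job
    lowest_first = sorted(running.items(), key=lambda item: item[1])
    to_preempt = []
    for (_, cand_prio), (run_name, run_prio) in zip(candidates, lowest_first):
        if cand_prio <= run_prio:
            break
        to_preempt.append(run_name)
    return to_start, to_preempt


# Queue state just after job4 is submitted in the use case example
print(plan_actions(pending=[('job4', 3), ('job3', 1)],
                   running={'job2': 2, 'job1': 1},
                   num_stopping=0, capacity=2))  # ([], ['job1'])
```

Keeping the decision pure makes edge cases easy to test, for example more running jobs than the configured capacity: the max(0, ...) guard keeps a negative slot count from being used as a slice index.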

Define the AWS Lambda Function

The next component of the solution is the Lambda function. The following code block includes the SAM definition of our serverless function. We program the function to run on two types of events: any call to add-job on our private API, and a change in the state of a SageMaker training job to Completed, Failed, or Stopped.

  ManagedTrainingJobQueue:
    Type: AWS::Serverless::Function  
    Properties:  
      CodeUri: job-queue/ # the directory containing our index.py file  
      Handler: index.lambda_handler  
      Runtime: python3.12  
      Architectures:  
        - arm64 # use graviton  
      Policies: # allow access to SageMaker and DynamoDB  
        - !Sub "arn:${AWS::Partition}:iam::aws:policy/AmazonSageMakerFullAccess"  
        - DynamoDBCrudPolicy:  
            TableName: !Ref DynamoSMQueue  
      Events:  
        CreateTraining:  
          Type: Api  
          Properties:  
            Path: /add-job  
            Method: post  
            RestApiId: !Ref InternalAPI  
        SageMakerEvent:  
          Type: EventBridgeRule  
          Properties:  
            Pattern:  
              source:  
                - aws.sagemaker  
              detail-type:  
                - SageMaker Training Job State Change  
              detail:  
                TrainingJobStatus:  
                  - "Completed"  
                  - "Failed"  
                  - "Stopped"

The lambda_handler function is implemented as follows:

def lambda_handler(event, context):
    # identify source of event and take appropriate action
    if 'requestContext' in event and 'apiId' in event['requestContext']:
        print('Lambda triggered by API Gateway')
        # pass the raw JSON body; add_job_entry parses it and stores it as-is
        add_job_entry(event.get('body'))
    elif 'source' in event and event['source'] == 'aws.sagemaker':
        print('Lambda triggered by SageMaker job state change')
        job_name = event['detail']['TrainingJobName']
        job_status = event['detail']['TrainingJobStatus']
        print(f'{job_name} status changed to {job_status}')

        # strip qstamp from job_name
        job_name = strip_qstamp(job_name)

        if job_status in ['Completed', 'Failed']:
            remove_job(job_name)
        elif job_status == 'Stopped':
            # check if it was manually stopped or preempted by queue manager
            if get_job_state(job_name) == 'preempted':
                print(f'job {job_name} preemption completed')
            else:
                print(f'job {job_name} {job_status}, remove from queue')
                remove_job(job_name)

    # in all cases invoke queue manager
    manage_queue()

    # API Gateway proxy integrations expect a response with a status code
    return {'statusCode': 200, 'body': 'ok'}
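For local testing, it helps to see the two kinds of events the handler distinguishes. The sketch below uses minimal, trimmed-down example payloads (real events carry many more fields, and the apiId value is a placeholder) and factors the handler's dispatch and queue-cleanup rules into two hypothetical pure functions, event_source and cleanup_action.

```python
# Trimmed example of an API Gateway (REST, proxy integration) event for add-job
api_event = {
    'requestContext': {'apiId': 'abc123'},  # placeholder API id
    'body': '{"TrainingJobName": "job1", "Tags": []}',
}

# Trimmed example of an EventBridge SageMaker job state change event;
# the job name carries a queue stamp as produced by apply_qstamp
sm_event = {
    'source': 'aws.sagemaker',
    'detail-type': 'SageMaker Training Job State Change',
    'detail': {'TrainingJobName': 'job1-qstamp-011230',
               'TrainingJobStatus': 'Stopped'},
}


def event_source(event):
    """Classify an incoming event the same way lambda_handler does."""
    if 'apiId' in event.get('requestContext', {}):
        return 'api'
    if event.get('source') == 'aws.sagemaker':
        return 'sagemaker'
    return 'unknown'


def cleanup_action(job_status, queue_state):
    """Return 'remove' or 'keep' for a queue entry after a job state change."""
    if job_status in ('Completed', 'Failed'):
        return 'remove'
    if job_status == 'Stopped':
        # a stop issued by the queue manager leaves the entry for a restart
        return 'keep' if queue_state == 'preempted' else 'remove'
    return 'keep'


print(event_source(api_event), event_source(sm_event))  # api sagemaker
print(cleanup_action('Stopped', 'preempted'))            # keep
print(cleanup_action('Stopped', 'running'))              # remove
```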

Intercept the Create Training Job Request

The final modification required to make our solution complete is to intercept the call to the SageMaker create_training_job API and reroute it to our add-job method. We do this by overriding the _intercept_create_request function of the SageMaker Session class:

import json, logging
import requests
from sagemaker.pytorch import PyTorch
from sagemaker.session import Session
logger = logging.getLogger('sagemaker')

def submit_to_training_queue(job):
    logger.info(f"Adding training-job {job['TrainingJobName']} to queue")
    logger.debug(f'train request: {json.dumps(job, indent=4)}')

    vpce='<vpc endpoint>' # insert id of vpc endpoint
    region='us-east-1' # specify region
    url=f'https://{vpce}.execute-api.{region}.vpce.amazonaws.com/prod/add-job'
    headers = {'x-apigw-api-id': '<api-id>'} # insert api gateway id

    # submit job
    response = requests.post(url, headers=headers, json=job)
    logger.info(f'add-job response status: {response.status_code}')

class QueueTrainingJobSession(Session):
    def _intercept_create_request(self, request, create, func_name = None):
        """This function intercepts the create job request

        Args:
          request (dict): the create job request
          create (functor): a functor calls the sagemaker client create method
          func_name (str): the name of the function needed intercepting
        """
        if func_name == 'train':
            submit_to_training_queue(request)
        else:
            return super()._intercept_create_request(request, create, func_name)

# define job
estimator = PyTorch(
    role='<sagemaker role>',
    entry_point='train.py',
    instance_type='ml.p5.48xlarge',
    instance_count=1,
    framework_version='2.0.1',
    py_version='py310',
    tags=[{'Key': 'priority', 'Value': '100'}],
    keep_alive_period_in_seconds=60, # keep warm for 1 minute
    # use our custom Session class
    sagemaker_session=QueueTrainingJobSession()
)

estimator.fit(wait=False)

Use Case Example

To test our solution we submit the following sequence of jobs. After each call we print the status of the queue (using the print_queue_state function) and sleep for twenty seconds.

  1. Start job1 with priority 1.
  2. Start job2 with priority 2.
  3. Start job3 with priority 1.
  4. Start job4 with priority 3.

The first two jobs are immediately submitted to SageMaker and updated to the running state. Since we have precisely two training instances and the third job’s priority is not higher than that of either running job, it remains in the pending state and waits its turn. After submitting the first three jobs, the queue state appears as:

Job Name    State      Priority  
----------  -------  ----------  
job2        running           2  
job1        running           1  
job3        pending           1

The fourth job we submit has a higher priority than all of the jobs in the queue. Consequently, the running job with the lowest priority, job1, is preempted. The corresponding SageMaker job is stopped and once the instance is released, the queue state becomes:

Job Name    State        Priority  
----------  ---------  ----------  
job4        running             3  
job2        running             2  
job1        preempted           1  
job3        pending             1

The SageMaker job running job2 is the first to finish, so job2 is removed from the queue and our preempted job, job1, is started again (job1 precedes job3 because both have priority 1 and job1 was submitted earlier):

Job Name    State      Priority  
----------  -------  ----------  
job4        running           3  
job1        running           1  
job3        pending           1

Once job4 is completed, it too is removed from the queue, making room for job3. The remaining jobs are also run to completion, ultimately leaving our queue empty.

Summary

The increasing difficulty of acquiring AI compute capacity has forced AI development teams to reevaluate the processes they use for training AI models. The approach we have demonstrated in this post is to augment the traditional APIs for training models with a custom-made priority queue and an associated job scheduler. Importantly, the proposal we have put forth should be viewed as a general scheme, not as a production-worthy solution. Appropriate modifications and enhancements would be required to address the specific needs of your project.
