A simple routine that can save you loads of money
Photo by Ziyou Zhang on Unsplash
This blog post was co-authored with my colleague Shay Margalit. It summarizes his research into how AWS Lambda functions can be used to increase the control over the usage and costs of the Amazon SageMaker training service. Interested? Please read on :).
We are fortunate (or very unfortunate - depending on who you ask) to be sharing a front row seat to an AI revolution that is expected by many to change the world as we know it. Powered by advances in hardware development and access to enormous amounts of data, this revolution is likely to impact many aspects of our daily lives - although precisely how, no one can say for sure. To support the growing appetite for artificial intelligence, the sizes of the underlying machine learning models are increasing rapidly as are the resources that are required to train them. The bottom line is that staying relevant in the AI development playing field requires a sizable investment into heavy, and expensive, machinery.
Cloud-based managed training services, such as Amazon SageMaker, Google Vertex AI, and Microsoft Azure ML, have lowered the entry barrier to AI development by enabling developers to train on machines that they could otherwise not afford. Although such services reduce the upfront costs of AI and enable you to pay only for the time you spend training, the potential for the variable costs to add up warrants careful planning of how the training services will be used and how they will contribute to your overall training expense. However, inevitably, things don't always go according to plan. To paraphrase an old Yiddish proverb, "developers plan and the programming gods laugh". When the stakes are high, as they are when training AI models, where an errant experiment can result in hundreds or thousands of dollars' worth of wasted compute time, it is wise to institute multiple lines of defense.
First Line of Defense - Encourage Healthy Development Habits
The first line of defense should address the development practices of the ML algorithm engineers. Here are examples of some guiding principles you might consider:
- Encourage appropriate and cost-optimal use of the hardware resources used for training (e.g., see here).
- Identify and terminate failing experiments early.
- Increase price performance by regularly analyzing and optimizing runtime performance (e.g., see here).
While formulating and adopting AI development principles such as the ones above is likely to increase your productivity and reduce waste, such principles do not offer full protection against all possible failures. For example, a dedicated failure-detection runtime process may not help in a situation where a training experiment stalls (e.g., due to a deadlock in the training application's processes) but the training job remains active until it is actively stopped or times out.
Second Line of Defense - Deploy Cross-project Guardrails
In this post we propose instituting a second line of defense that monitors all of the training activities in the project (or organization), verifies their compliance with a predetermined set of rules, and takes appropriate action when errant training experiments are identified. One way to do this is to use dedicated serverless functions that are triggered at different stages of a training job and programmed to evaluate the job's state and, where appropriate, stop or restart it (possibly with changes to the job settings). In the next sections we will demonstrate a few examples of how to use AWS Lambda as a second line of defense against errant Amazon SageMaker training experiments.
Disclaimers
Although we have chosen Amazon SageMaker and AWS Lambda for our demonstrations, the contents of the post are just as relevant to other services and similar functionality can be implemented for them. Please do not interpret our choice of these services as an endorsement of their use over their alternatives. There are multiple options available for cloud-based training each with their own advantages and disadvantages. The best choice for you will greatly depend on the details of your project.
While we will share a few Python examples of serverless code, we will not go into the details of how to create and deploy them as AWS Lambda functions. There are many ways of interacting with AWS Lambda. We refer the reader to the official AWS documentation to learn more about them.
The examples below were created for demonstrative purposes. They will likely require modification to suit the specific needs of your project. Be sure to fully understand all of the details of the code and the associated service costs before adopting the type of solution we propose. Importantly, the code we will share has not undergone rigorous testing. Any solution that includes creation and invocation of multiple Lambda functions and Amazon CloudWatch alarms (as described here) requires appropriate validation to prevent the accumulation of redundant/orphan artifacts.
We highly advise that you verify the details of this post against the most up-to-date AWS Lambda documentation and most up-to-date versions of the supporting libraries.
Enforcing Developer Compliance
While cloud governance is often vital for successful and efficient use of cloud services, its enforcement can sometimes be challenging. For example: Amazon SageMaker includes an API for appending tags to training jobs. These can be used to include metadata associated with the SageMaker job such as the name of the training project, the stage of development, the goal of the current trial, the name of the development group or user running the job, etc. This metadata can be used to collect statistics such as the cost of development per project or group. In the code block below, we demonstrate the application of several tags to a SageMaker training job:
from sagemaker.pytorch import PyTorch

tags = [{'Key': 'username', 'Value': 'johndoe'},
        {'Key': 'model_name', 'Value': 'mnist'},
        {'Key': 'training_phase', 'Value': 'finetune'},
        {'Key': 'description', 'Value': 'fine tune final linear layer'}]

# define the training job with tags
estimator = PyTorch(
    entry_point='train.py',
    framework_version='2.1.0',
    role='<arn role>',
    py_version='py310',
    job_name='demo',
    instance_type='ml.g5.xlarge',
    instance_count=1,
    tags=tags
)

# deploy the job to the cloud
estimator.fit()
Naturally, these tags are only helpful if we can enforce their application. This is where AWS Lambda comes to the rescue. Using Amazon EventBridge we can monitor changes in the status of SageMaker training jobs and register a function that will be triggered on every change. In the code block below, we propose a Python routine that verifies the presence of specific SageMaker tags every time a job is started. If a required tag is missing, the job is automatically terminated. The structure of the event is documented here. Note the use of the (more detailed) SecondaryStatus field to check the status of the training job (rather than TrainingJobStatus).
import boto3

def stop_training_job(training_job_name):
    sm_client = boto3.client("sagemaker")
    response = sm_client.stop_training_job(TrainingJobName=training_job_name)
    assert response['ResponseMetadata']['HTTPStatusCode'] == 200
    # TODO - optionally send an email notification

def enforce_required_tags(training_job_name, event):
    event_tags = event['detail'].get('Tags', [])
    # tags may arrive as a list of {'Key': ..., 'Value': ...} entries
    # or as a plain mapping; collect the tag keys in either case
    if isinstance(event_tags, dict):
        tag_keys = set(event_tags)
    else:
        tag_keys = {tag['Key'] for tag in event_tags}
    if 'model_name' not in tag_keys:
        stop_training_job(training_job_name)

# define lambda handler
def sagemaker_event_handler(event, _):
    job_name = event['detail']['TrainingJobName']
    job_secondary_status = event['detail']['SecondaryStatus']
    if job_secondary_status == 'Starting':
        enforce_required_tags(job_name, event)
AWS offers multiple ways for creating a Lambda function. Please see the AWS Lambda documentation for details. Once created, make sure to set the function as the target of the EventBridge rule.
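As a rough illustration of the EventBridge side, the sketch below builds an event pattern that matches SageMaker training job state changes and attaches a Lambda function as the rule's target. The rule name, the target id, and the `register_rule` helper are placeholders we made up for this example. The Lambda function also needs a resource-based permission that allows EventBridge to invoke it, which is not shown here.

```python
import json

# Event pattern that matches every status change of every SageMaker training job.
TRAINING_JOB_EVENT_PATTERN = {
    "source": ["aws.sagemaker"],
    "detail-type": ["SageMaker Training Job State Change"],
}

def register_rule(lambda_arn, rule_name="sagemaker-training-guardrail"):
    # rule_name and the target Id below are hypothetical names chosen for this sketch
    import boto3  # imported lazily so the pattern can be inspected without AWS access

    events_client = boto3.client("events")
    # create (or update) the rule
    events_client.put_rule(
        Name=rule_name,
        EventPattern=json.dumps(TRAINING_JOB_EVENT_PATTERN),
        State="ENABLED",
    )
    # point the rule at the Lambda function
    events_client.put_targets(
        Rule=rule_name,
        Targets=[{"Id": "guardrail-lambda", "Arn": lambda_arn}],
    )
```

Calling `register_rule('<lambda function arn>')` once, from a machine with suitable AWS credentials, should be enough to start routing training job events to the function.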
The same function can be used to enforce additional development rules that are aimed at controlling cost such as: the types of instances that can be used, the maximum number of instances per job, the maximum runtime of a job, and more.
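To make this more concrete, here is a minimal sketch of such a rule check. It assumes a job description shaped like the response of the boto3 describe_training_job API (with its ResourceConfig and StoppingCondition fields). The policy values and the `find_violations` helper are hypothetical and would need to be adapted to your own rules.

```python
# Hypothetical policy values; adjust them to your organization's rules.
ALLOWED_INSTANCE_TYPES = {"ml.g5.xlarge", "ml.g5.2xlarge"}
MAX_INSTANCE_COUNT = 4
MAX_RUNTIME_SECONDS = 24 * 60 * 60  # one day

def find_violations(desc):
    """Return a list of human-readable rule violations for a training job."""
    violations = []
    resources = desc["ResourceConfig"]
    if resources["InstanceType"] not in ALLOWED_INSTANCE_TYPES:
        violations.append(f"instance type {resources['InstanceType']} is not allowed")
    if resources["InstanceCount"] > MAX_INSTANCE_COUNT:
        violations.append(f"more than {MAX_INSTANCE_COUNT} instances requested")
    max_runtime = desc["StoppingCondition"].get("MaxRuntimeInSeconds")
    if max_runtime is None or max_runtime > MAX_RUNTIME_SECONDS:
        violations.append(f"maximum runtime exceeds {MAX_RUNTIME_SECONDS} seconds")
    return violations

# example: a job that breaks all three rules
example_job = {
    "ResourceConfig": {"InstanceType": "ml.p4d.24xlarge", "InstanceCount": 8},
    "StoppingCondition": {"MaxRuntimeInSeconds": 3 * 24 * 60 * 60},
}
print(find_violations(example_job))
```

A handler could call this function when a job starts and stop the job (or just send a notification) whenever the returned list is not empty.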
Stopping Stalled Experiments
Imagine the following scenario: You have planned a large cloud-based training job that will run on eight $30-an-hour ML compute instances for a period of three days. For the purpose of this task, you have secured a budget of $17,280 (8 instances x $30 an hour x 24 hours x 3 days). You start up the training job just before heading out for a three-day holiday weekend. When you return from your holiday weekend, you discover that an hour into the job, the training process stalled causing the expensive machinery to essentially remain completely idle for three long days. Not only have you wasted $17,280 (good luck explaining that to your boss) but your development has now been pushed back by three days!!
One way to protect yourself against this type of occurrence is to monitor the utilization of the underlying training job resources. For example, if the GPU utilization of your training instances remains below a certain threshold for an extended period of time, this is likely a sign that something has gone wrong and that the training job should be stopped immediately.
We will do this by defining an Amazon CloudWatch alarm that monitors the GPU utilization of one of the training instances of each SageMaker job and invokes an AWS Lambda function that terminates the job if the alarm is triggered. Setting this up requires three components: an Amazon CloudWatch alarm (one per training job), an AWS Lambda function, and an Amazon Simple Notification Service (SNS) topic that is used to link the Lambda function to the CloudWatch alarms.
First, we create an SNS topic. This can be done via the Amazon SNS Console or in Python, as shown below:
import boto3

sns_client = boto3.client('sns')
# Create an SNS notification topic.
topic = sns_client.create_topic(Name="SageMakerTrainingJobIdleTopic")
# the client API returns a dictionary that holds the topic ARN
topic_arn = topic['TopicArn']
print(f"Created SNS topic arn: {topic_arn}")
Next, we extend the sagemaker_event_handler function we defined above to create a unique alarm each time a training job is started. We program the alarm to measure the average GPU utilization over five-minute periods and to alert our SNS topic when there are three consecutive measurements below 1%. The alarm is deleted when the job is completed.
def create_training_alarm(job_name):
    topic_arn = '<sns topic arn>'

    SAMPLE_PERIOD_SECONDS = 60 * 5 # 5 minutes
    SAMPLE_POINTS_LIMIT = 3
    GPU_UTIL_THRESHOLD_PERCENTAGE = 1

    cloudwatch_client = boto3.client('cloudwatch')

    # A new sample is generated every SAMPLE_PERIOD_SECONDS seconds.
    # The alarm goes off when SAMPLE_POINTS_LIMIT consecutive samples
    # are at or below the threshold.
    response = cloudwatch_client.put_metric_alarm(
        AlarmName=job_name + 'GPUUtil',
        AlarmActions=[topic_arn],  # must be a list of ARNs
        MetricName='GPUUtilization',
        Namespace='/aws/sagemaker/TrainingJobs',
        Statistic='Average',
        Dimensions=[{
            "Name": "Host",
            "Value": job_name+"/algo-1"
        }],
        Period=SAMPLE_PERIOD_SECONDS,
        EvaluationPeriods=SAMPLE_POINTS_LIMIT,
        DatapointsToAlarm=SAMPLE_POINTS_LIMIT,
        Threshold=GPU_UTIL_THRESHOLD_PERCENTAGE,
        ComparisonOperator='LessThanOrEqualToThreshold',
        TreatMissingData='notBreaching'
    )
    assert response['ResponseMetadata']['HTTPStatusCode'] == 200

def delete_training_alarm(job_name):
    cloudwatch_client = boto3.client('cloudwatch')
    response = cloudwatch_client.delete_alarms(
        AlarmNames=[job_name+'GPUUtil'])

def sagemaker_event_handler(event, _):
    job_name = event['detail']['TrainingJobName']
    job_secondary_status = event['detail']['SecondaryStatus']
    if job_secondary_status == 'Starting':
        enforce_required_tags(job_name, event)
    elif job_secondary_status == 'Training':
        create_training_alarm(job_name)
    elif job_secondary_status in ['Completed', 'Failed', 'Stopped']:
        delete_training_alarm(job_name)
Last, we define a second Python AWS Lambda function that parses messages received from the SNS topic and terminates the training job associated with the alarm.
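import boto3, json

# NOTE: stop_training_job is the helper defined earlier in this post;
# it needs to be included in the code of this Lambda function as well.
def lambda_sns_handler(event, context):
    data = json.loads(event['Records'][0]['Sns']['Message'])
    alarm_name = data['AlarmName']
    # strip the 'GPUUtil' suffix that was appended when the alarm was created
    training_job_name = alarm_name[:-len('GPUUtil')]
    stop_training_job(training_job_name)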
AWS offers multiple mechanisms for subscribing a Lambda function to an SNS topic including the AWS Console, AWS CLI, and the AWS Serverless Application Model (AWS SAM).
The solution we described is summarized in the following diagram:
Note that the same architecture can be used to enforce a minimum level of GPU utilization of your ML training projects. GPUs are typically the most expensive resource in your training infrastructure and your goal should be to maximize the utilization of all of your training workloads. By dictating a minimum level of utilization (e.g. 80%) you can ensure that all developers optimize their workloads appropriately.
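A minimal sketch of such a minimum-utilization alarm is shown below. It only builds the arguments for the CloudWatch put_metric_alarm call so that they can be reviewed before use. The `build_min_utilization_alarm` helper, the 80% default, and the 30-minute window are our own assumptions. Note also that, to our understanding, SageMaker reports GPUUtilization on multi-GPU instances on a scale that grows with the number of GPUs, so the threshold may need to be scaled accordingly; please verify this against the current documentation.

```python
def build_min_utilization_alarm(job_name, topic_arn, min_util_percent=80,
                                period_seconds=300, num_periods=6):
    """Build put_metric_alarm arguments for an alarm that fires when the
    average GPU utilization stays below min_util_percent for
    num_periods consecutive periods (hypothetical defaults)."""
    return dict(
        AlarmName=job_name + 'GPUUtil',  # same naming scheme as the idle alarm
        AlarmActions=[topic_arn],
        MetricName='GPUUtilization',
        Namespace='/aws/sagemaker/TrainingJobs',
        Statistic='Average',
        Dimensions=[{'Name': 'Host', 'Value': job_name + '/algo-1'}],
        Period=period_seconds,
        EvaluationPeriods=num_periods,
        DatapointsToAlarm=num_periods,
        Threshold=min_util_percent,
        ComparisonOperator='LessThanThreshold',
        TreatMissingData='notBreaching',
    )

alarm_args = build_min_utilization_alarm('demo', '<sns topic arn>')
# to create the alarm:
# boto3.client('cloudwatch').put_metric_alarm(**alarm_args)
```

Because the alarm name follows the same scheme as the idle-job alarm, this alarm would take its place rather than run alongside it. Using a separate name would require a matching change in the function that parses the SNS message.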
Ensuring Continuity of Development
In our previous example, we demonstrated how to identify and stop a stalled experiment. In the large training job scenario that we described, this helped save a lot of money, but it did not address the three-day delay to development. Obviously, if the source of the stall is in your code, it makes sense to postpone resuming training until the problem is fixed. However, we often encounter training interruptions that are not caused by our code but rather by sporadic failures in the service environment. In such scenarios, your priority may be to ensure training continuity rather than having to wait for someone to manually resume the training job (using the most recent training checkpoint). In the code block below, we use the boto3 create_training_job API to extend our sagemaker_event_handler function to (naively) resume any training job that has failed after running for at least two hours.
import boto3, datetime

def clone_job(training_name, disable_spot=False):
    # get description
    client = boto3.client('sagemaker')
    desc = client.describe_training_job(TrainingJobName=training_name)

    # update the training name
    new_training_name = training_name + 'clone'

    use_spots = (not disable_spot) and desc["EnableManagedSpotTraining"]

    if disable_spot:
        desc["StoppingCondition"].pop("MaxWaitTimeInSeconds", None)

    # NOTE: optional settings such as InputDataConfig and CheckpointConfig
    # are not copied here; add them if your job relies on them
    client.create_training_job(
        TrainingJobName=new_training_name,
        HyperParameters=desc["HyperParameters"],
        AlgorithmSpecification=desc["AlgorithmSpecification"],
        RoleArn=desc["RoleArn"],
        OutputDataConfig=desc["OutputDataConfig"],
        ResourceConfig=desc["ResourceConfig"],
        StoppingCondition=desc["StoppingCondition"],
        EnableNetworkIsolation=desc["EnableNetworkIsolation"],
        EnableInterContainerTrafficEncryption=desc[
            "EnableInterContainerTrafficEncryption"
        ],
        EnableManagedSpotTraining=use_spots,
        # list_tags returns a dictionary; the tags are under the 'Tags' key
        Tags=client.list_tags(ResourceArn=desc['TrainingJobArn'])['Tags']
    )

def sagemaker_event_handler(event, _):
    TRAIN_TIME_THRESHOLD = 2 * 60 * 60 # 2 hours
    job_name = event['detail']['TrainingJobName']
    job_secondary_status = event['detail']['SecondaryStatus']
    if job_secondary_status == 'Starting':
        enforce_required_tags(job_name, event)
    elif job_secondary_status == 'Training':
        create_training_alarm(job_name)
    elif job_secondary_status in ['Completed', 'Failed', 'Stopped']:
        delete_training_alarm(job_name)

    if job_secondary_status == 'Failed':
        start_time = datetime.datetime.utcfromtimestamp(
            event['detail']['CreationTime']/1000)
        end_time = datetime.datetime.utcfromtimestamp(
            event['detail']['TrainingEndTime']/1000)
        # total_seconds() accounts for durations longer than one day
        training_time_seconds = (end_time - start_time).total_seconds()
        if training_time_seconds >= TRAIN_TIME_THRESHOLD:
            clone_job(job_name)
The function above automatically resumes any job that fails after two hours. A more practical solution might attempt to diagnose the type of error to determine whether resuming the job would be appropriate. One way to do this is to parse the failure description message and/or the CloudWatch logs associated with the failing job.
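As a rough sketch of the first option, the helper below inspects the failure description (the FailureReason field returned by the boto3 describe_training_job API) and treats only failures that look service-related as candidates for an automatic restart. The `should_resume` helper and the list of prefixes are illustrative assumptions on our part; the list is certainly not exhaustive and should be checked against the failure messages you actually encounter.

```python
# Illustrative, non-exhaustive prefixes of failure reasons that we assume
# point to the service environment rather than to the training code.
TRANSIENT_FAILURE_PREFIXES = ("InternalServerError", "CapacityError")

def should_resume(failure_reason):
    """Return True if the failure looks unrelated to the user's code."""
    # str.startswith accepts a tuple of candidate prefixes
    return (failure_reason or "").startswith(TRANSIENT_FAILURE_PREFIXES)

print(should_resume("InternalServerError: We encountered an internal error."))
print(should_resume("AlgorithmError: ExecuteUserScriptError"))
```

In the handler above, a check of this kind could be combined with the training time threshold before calling clone_job.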
Advanced Spot-instance Utilization
One of the compelling features of Amazon SageMaker is its support for managed spot training. Amazon EC2 Spot Instances allow you to take advantage of unused EC2 capacity at discounted prices. The catch is that these instances can be taken away ("interrupted") in the middle of their use. Thus, Spot instances should be used only for fault-tolerant workloads. SageMaker makes it easy to take advantage of Spot instances by identifying Spot interruptions on your behalf and automatically restarting jobs when new Spot instances become available. While managed spot instances can be used to reduce the cost of training, sometimes this strategy can backfire. For example, when there is low spot capacity your training jobs might time out before starting. Alternatively, the job might experience frequent interruptions that prevent it from making any meaningful progress. Both occurrences can interfere with development and reduce productivity. These types of situations can be monitored and addressed using AWS Lambda. In the code block below, we extend our sagemaker_event_handler function to identify a training job that has been interrupted more than three times and replace it with a cloned job in which managed spot training is disabled.
def sagemaker_event_handler(event, _):
    TRAIN_TIME_THRESHOLD = 2 * 60 * 60 # 2 hours
    MAX_INTERRUPTS = 3
    job_name = event['detail']['TrainingJobName']
    job_secondary_status = event['detail']['SecondaryStatus']
    if job_secondary_status == 'Starting':
        enforce_required_tags(job_name, event)
    elif job_secondary_status == 'Training':
        create_training_alarm(job_name)
    elif job_secondary_status in ['Completed', 'Failed', 'Stopped']:
        delete_training_alarm(job_name)

    if job_secondary_status == 'Failed':
        start_time = datetime.datetime.utcfromtimestamp(
            event['detail']['CreationTime']/1000)
        end_time = datetime.datetime.utcfromtimestamp(
            event['detail']['TrainingEndTime']/1000)
        # total_seconds() accounts for durations longer than one day
        training_time_seconds = (end_time - start_time).total_seconds()
        if training_time_seconds >= TRAIN_TIME_THRESHOLD:
            clone_job(job_name)

    if job_secondary_status == 'Interrupted':
        transitions = event['detail']["SecondaryStatusTransitions"]
        interrupts = [e for e in transitions if e["Status"] == "Interrupted"]
        num_interrupts = len(interrupts)
        # replace the job once it has been interrupted more than MAX_INTERRUPTS times
        if num_interrupts > MAX_INTERRUPTS:
            stop_training_job(job_name)
            clone_job(job_name, disable_spot=True)
The implementation above determines the spot usage strategy based solely on the number of interruptions of the training job in question. A more elaborate solution might take into account other jobs (that use the same instance types), the duration of time across which the interruptions occurred, the amount of active training time, and/or the number of recent jobs that timed out due to low Spot instance capacity.
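As one example of such a refinement, the sketch below counts only the interruptions that began within a recent time window. It assumes that each entry in SecondaryStatusTransitions carries a StartTime in epoch milliseconds, in line with how the event timestamps are treated in the code above. The `count_recent_interrupts` helper and the six-hour default window are hypothetical.

```python
def count_recent_interrupts(transitions, now_ms, window_seconds=6 * 60 * 60):
    """Count the 'Interrupted' transitions that started within the last
    window_seconds, given timestamps in epoch milliseconds."""
    window_start_ms = now_ms - window_seconds * 1000
    return sum(
        1 for t in transitions
        if t["Status"] == "Interrupted" and t["StartTime"] >= window_start_ms
    )

# example: three interruptions overall, two of them within the last hour
example_transitions = [
    {"Status": "Interrupted", "StartTime": 1_000},
    {"Status": "Training", "StartTime": 5_000_000},
    {"Status": "Interrupted", "StartTime": 20_000_000},
    {"Status": "Interrupted", "StartTime": 21_000_000},
]
recent = count_recent_interrupts(example_transitions, now_ms=22_000_000,
                                 window_seconds=60 * 60)
print(recent)
```

The handler could then compare this count, rather than the total number of interruptions, against its limit, so that a long-running job with a few widely spaced interruptions is left alone.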
Summary
Effective AI model development requires the definition of a creative and detailed training infrastructure architecture in order to minimize cost and maximize productivity. In this post we have demonstrated how serverless AWS Lambda functions can be used to augment Amazon SageMaker's managed training service in order to address some common issues that can occur during training. Naturally, the precise manner in which you might apply these kinds of techniques will depend greatly on the specifics of your project.
Please feel free to reach out with questions, comments, and corrections. Be sure to check out our other posts on the topic of DL training optimization.