Intro
On our journey to build cheap, scalable GPU inference at Tonar, we settled on SageMaker Async Inference. We manage our infrastructure with Terraform, and this is a small how-to on getting CI/CD to actually work when you continuously update the models behind SageMaker endpoints. We will use a custom Whisper model as the running example.
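For anyone new to Async Inference: you don't send the payload with the request, you point SageMaker at an object already in S3, and it writes the result back to S3 when done. A minimal sketch of what a client call looks like (the endpoint name and S3 locations here are placeholders, not our real values):

import boto3

# Async inference reads its input from S3 and writes the result back to S3;
# the HTTP call itself returns immediately with a pointer to the future output.
runtime = boto3.client("sagemaker-runtime", region_name="eu-north-1")

response = runtime.invoke_endpoint_async(
    EndpointName="whisper",                           # placeholder endpoint name
    InputLocation="s3://some-bucket/inputs/job.json", # hypothetical payload location
    ContentType="application/json",
)

print(response["InferenceId"], response["OutputLocation"])

The SNS success/error topics you will see in the Terraform below are what tell you when that output actually lands.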
Expected behaviour...
Reading the documentation of both AWS and Terraform, you would end up with something like the code below for an autoscaling endpoint.
resource "aws_sagemaker_model" "whisper" {
execution_role_arn = aws_iam_role.whisper.arn
name = "whisper"
primary_container {
image = data.aws_ssm_parameter.whisper_digest_id.value
model_data_url = "s3://..."
environment = {
STAGE = terraform.workspace
CONTAINER_ID = data.aws_ssm_parameter.whisper_digest_id.value
SENTRY_DSN = data.aws_ssm_parameter.sentry_dsn.value
SENTRY_ENVIRONMENT = terraform.workspace
INSTANCE_TYPE = "ml.g4dn.xlarge"
}
}
}
resource "aws_sagemaker_endpoint_configuration" "whisper" {
production_variants {
model_name = aws_sagemaker_model.whisper.name
variant_name = "AllTraffic"
initial_instance_count = 1
instance_type = "ml.g4dn.xlarge"
container_startup_health_check_timeout_in_seconds = 1800
model_data_download_timeout_in_seconds = 1200
}
async_inference_config {
output_config {
s3_output_path = "s3://..."
s3_failure_path = "s3://..."
notification_config {
include_inference_response_in = ["SUCCESS_NOTIFICATION_TOPIC", "ERROR_NOTIFICATION_TOPIC"]
success_topic = aws_sns_topic.whisper_success_topic.arn
error_topic = aws_sns_topic.whisper_error_topic.arn
}
}
client_config {
max_concurrent_invocations_per_instance = 1
}
}
}
resource "aws_sagemaker_endpoint" "whisper" {
name = "whisper"
endpoint_config_name = aws_sagemaker_endpoint_configuration.whisper.name
deployment_config {
rolling_update_policy {
maximum_batch_size {
type = "CAPACITY_PERCENT"
value = 50
}
maximum_execution_timeout_in_seconds = 900
wait_interval_in_seconds = 180
}
}
depends_on = [aws_sagemaker_model.whisper, aws_sagemaker_endpoint_configuration.whisper]
}
resource "aws_appautoscaling_target" "sagemaker_target" {
max_capacity = 30
min_capacity = 0
resource_id = "endpoint/${aws_sagemaker_endpoint.whisper.name}/variant/AllTraffic"
role_arn = aws_iam_role.whisper.arn
scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
service_namespace = "sagemaker"
depends_on = [aws_sagemaker_endpoint.whisper, aws_sagemaker_endpoint_configuration.whisper]
}
resource "aws_appautoscaling_policy" "sagemaker_policy_regular" {
name = "whisper-invocations-scaling-policy"
policy_type = "TargetTrackingScaling"
resource_id = aws_appautoscaling_target.sagemaker_target.resource_id
scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
service_namespace = aws_appautoscaling_target.sagemaker_target.service_namespace
target_tracking_scaling_policy_configuration {
customized_metric_specification {
metric_name = "ApproximateBacklogSizePerInstance"
namespace = "AWS/SageMaker"
dimensions {
name = "EndpointName"
value = aws_sagemaker_endpoint.whisper.name
}
statistic = "Average"
}
target_value = 10
scale_in_cooldown = 30
scale_out_cooldown = 120
}
}
// Scales from 0 to 1 without waiting for queue to fill up
resource "aws_appautoscaling_policy" "sagemaker_policy_zero_to_one" {
name = "whisper-backlog-without-capacity-scaling-policy"
policy_type = "StepScaling"
resource_id = aws_appautoscaling_target.sagemaker_target.resource_id
scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
service_namespace = aws_appautoscaling_target.sagemaker_target.service_namespace
step_scaling_policy_configuration {
adjustment_type = "ChangeInCapacity"
cooldown = 300
metric_aggregation_type = "Average"
step_adjustment {
metric_interval_lower_bound = 0
scaling_adjustment = 1
}
}
}
resource "aws_cloudwatch_metric_alarm" "sagemaker_policy_zero_to_one" {
alarm_name = "whisper-backlog-without-capacity-scaling-policy"
metric_name = "HasBacklogWithoutCapacity"
namespace = "AWS/SageMaker"
statistic = "Average"
evaluation_periods = 2
datapoints_to_alarm = 2
comparison_operator = "GreaterThanOrEqualToThreshold"
threshold = 1
treat_missing_data = "missing"
dimensions = {
EndpointName = aws_sagemaker_endpoint.whisper.name
}
period = 60
alarm_description = "This metric is used to trigger the scaling policy that scales from 0 to 1 without waiting for queue to fill up"
alarm_actions = [aws_appautoscaling_policy.sagemaker_policy_zero_to_one.arn]
}
The problem
Why this does not work comes down to Terraform not fully understanding the update chain that SageMaker requires. Updating the model results in one of two scenarios:
- The model and endpoint configuration are updated, but the changes are not applied to the endpoint, so you have to destroy the endpoint to apply them (not good for production...).
- The endpoint update fails because it still refers to the previous version of the endpoint configuration (which Terraform has already destroyed), which again leads to deleting the endpoint...
The dirty solution
So if you really just want to update the underlying code of the model (like any other container), push it to ECR, and have it roll out to your endpoint, this is the only solution I have come up with so far.
First, we change the Terraform code to:
resource "aws_sagemaker_model" "whisper" {
execution_role_arn = aws_iam_role.whisper.arn
primary_container {
image = data.aws_ssm_parameter.whisper_digest_id.value
model_data_url = "s3://..."
environment = {
STAGE = terraform.workspace
CONTAINER_ID = data.aws_ssm_parameter.whisper_digest_id.value
SENTRY_DSN = data.aws_ssm_parameter.sentry_dsn.value
SENTRY_ENVIRONMENT = terraform.workspace
INSTANCE_TYPE = "ml.g4dn.xlarge"
}
}
}
resource "aws_cloudwatch_metric_alarm" "sagemaker_endpoint_error_rate" {
alarm_name = "EndToEndDeploymentHighErrorRateAlarm"
alarm_description = "Monitors the error rate of 4xx errors"
metric_name = "Invocation4XXErrors"
namespace = "AWS/SageMaker"
statistic = "Average"
evaluation_periods = 2
comparison_operator = "GreaterThanThreshold"
threshold = 1
period = 600
treat_missing_data = "notBreaching"
dimensions = {
EndpointName = "whisper"
VariantName = "AllTraffic"
}
}
data "external" "deploy_model" {
program = ["python", "${path.module}/deploy_model.py"]
query = {
deploy_action = var.DEPLOY_ACTION
aws_access_key_id = var.ACCESS_KEY
aws_secret_access_key = var.SECRET_KEY
endpoint_name = "whisper"
endpoint_alarm_name = "EndToEndDeploymentHighErrorRateAlarm"
endpoint_config_model_name = aws_sagemaker_model.whisper.name
endpoint_config_model_image = data.aws_ssm_parameter.whisper_digest_id.value
endpoint_config_instance_type = "ml.g4dn.xlarge"
endpoint_config_output_path = "s3://..."
endpoint_config_error_path = "s3://..."
endpoint_config_success_topic = aws_sns_topic.whisper_success_topic.arn
endpoint_config_error_topic = aws_sns_topic.whisper_error_topic.arn
}
depends_on = [aws_sagemaker_model.whisper]
}
output "model_deployment" {
description = "Model deployment"
value = data.external.deploy_model.result
}
resource "aws_appautoscaling_target" "sagemaker_target" {
max_capacity = 30
min_capacity = 0
resource_id = "endpoint/${data.external.deploy_model.result["endpoint_name"]}/variant/AllTraffic"
role_arn = aws_iam_role.whisper.arn
scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
service_namespace = "sagemaker"
depends_on = [data.external.deploy_model]
}
resource "aws_appautoscaling_policy" "sagemaker_policy_regular" {
name = "whisper-invocations-scaling-policy"
policy_type = "TargetTrackingScaling"
resource_id = aws_appautoscaling_target.sagemaker_target.resource_id
scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
service_namespace = aws_appautoscaling_target.sagemaker_target.service_namespace
target_tracking_scaling_policy_configuration {
customized_metric_specification {
metric_name = "ApproximateBacklogSizePerInstance"
namespace = "AWS/SageMaker"
dimensions {
name = "EndpointName"
value = data.external.deploy_model.result["endpoint_name"]
}
statistic = "Average"
}
target_value = 10
scale_in_cooldown = 30
scale_out_cooldown = 120
}
}
// Scales from 0 to 1 without waiting for queue to fill up
resource "aws_appautoscaling_policy" "sagemaker_policy_zero_to_one" {
name = "whisper-backlog-without-capacity-scaling-policy"
policy_type = "StepScaling"
resource_id = aws_appautoscaling_target.sagemaker_target.resource_id
scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
service_namespace = aws_appautoscaling_target.sagemaker_target.service_namespace
step_scaling_policy_configuration {
adjustment_type = "ChangeInCapacity"
cooldown = 300
metric_aggregation_type = "Average"
step_adjustment {
metric_interval_lower_bound = 0
scaling_adjustment = 1
}
}
}
resource "aws_cloudwatch_metric_alarm" "sagemaker_policy_zero_to_one" {
alarm_name = "whisper-backlog-without-capacity-scaling-policy"
metric_name = "HasBacklogWithoutCapacity"
namespace = "AWS/SageMaker"
statistic = "Average"
evaluation_periods = 2
datapoints_to_alarm = 2
comparison_operator = "GreaterThanOrEqualToThreshold"
threshold = 1
treat_missing_data = "missing"
dimensions = {
EndpointName = data.external.deploy_model.result["endpoint_name"]
}
period = 60
alarm_description = "This metric is used to trigger the scaling policy that scales from 0 to 1 without waiting for queue to fill up"
alarm_actions = [aws_appautoscaling_policy.sagemaker_policy_zero_to_one.arn]
}
Basically, we move the entire CRUD management of the endpoint, endpoint configuration, and model into a Python script that runs boto3. I have hard-coded some values, as you can see below, and pass others into the script as parameters. I am also passing in the Terraform action (plan, apply, etc.), since I do not want the script to do anything unless it is an apply. Maybe not the best way, but it fits well with our GitHub Actions.
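Before the script itself, a quick note on the contract the external data source enforces: Terraform serializes the query map to JSON and pipes it to the program on stdin, and the program must print a single flat JSON object of string values to stdout. A stripped-down sketch of that round trip, separate from the real script:

import json
import sys

# Terraform pipes the query block in as one JSON object on stdin...
query = json.loads(sys.stdin.read())

# ...and expects exactly one flat JSON object of string keys/values on stdout.
# Stray prints or non-string values will make the data source error out.
print(json.dumps({"endpoint_name": query["endpoint_name"]}))

With that in mind, here is the full deploy_model.py: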
import boto3
import json
import sys
from datetime import datetime
from time import sleep


def endpoint_exists(client, endpoint_name: str) -> tuple[str | None, str | None]:
    # Returns the active endpoint config name and the currently deployed
    # image, or (None, None) if the endpoint does not exist yet.
    try:
        res = client.describe_endpoint(EndpointName=endpoint_name)
        return (
            res["EndpointConfigName"],
            res["ProductionVariants"][0]["DeployedImages"][0]["SpecifiedImage"],
        )
    except Exception:
        return None, None


def create_endpoint_config(
    client,
    name: str,
    model_name: str,
    instance_type: str,
    output_path: str,
    error_path: str,
    success_topic: str,
    error_topic: str,
):
    # Endpoint config names must be unique, so suffix with a timestamp.
    config_name = f"{name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
    res = client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[
            {
                "ModelName": model_name,
                "VariantName": "AllTraffic",
                "InstanceType": instance_type,
                "InitialInstanceCount": 1,
                "ContainerStartupHealthCheckTimeoutInSeconds": 1800,
                "ModelDataDownloadTimeoutInSeconds": 1200,
            }
        ],
        AsyncInferenceConfig={
            "OutputConfig": {
                "S3OutputPath": output_path,
                "S3FailurePath": error_path,
                "NotificationConfig": {
                    "IncludeInferenceResponseIn": [
                        "SUCCESS_NOTIFICATION_TOPIC",
                        "ERROR_NOTIFICATION_TOPIC",
                    ],
                    "SuccessTopic": success_topic,
                    "ErrorTopic": error_topic,
                },
            },
            "ClientConfig": {
                "MaxConcurrentInvocationsPerInstance": 1,
            },
        },
    )
    return config_name, res["EndpointConfigArn"]


def create_endpoint(
    client, endpoint_name: str, endpoint_config_name: str, alarm_name: str
):
    client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
        DeploymentConfig={
            "BlueGreenUpdatePolicy": {
                "TrafficRoutingConfiguration": {
                    "Type": "ALL_AT_ONCE",
                    "WaitIntervalInSeconds": 0,
                },
            },
            "AutoRollbackConfiguration": {
                "Alarms": [
                    {"AlarmName": alarm_name},
                ]
            },
        },
    )
    waiter = client.get_waiter("endpoint_in_service")
    waiter.wait(
        EndpointName=endpoint_name, WaiterConfig={"Delay": 10, "MaxAttempts": 60}
    )


def delete_endpoint_config(client, endpoint_config_name: str):
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)


def update_endpoint(
    client, endpoint_name: str, endpoint_config_name: str, alarm_name: str
):
    # An endpoint cannot be updated while another update is still in progress,
    # so retry until SageMaker accepts the new config. alarm_name is kept in
    # the signature for symmetry with create_endpoint; the rollback alarms are
    # retained from creation via RetainDeploymentConfig.
    retry_delay = 10
    for _ in range(60):
        try:
            client.update_endpoint(
                EndpointName=endpoint_name,
                EndpointConfigName=endpoint_config_name,
                RetainAllVariantProperties=True,
                RetainDeploymentConfig=True,
            )
            break
        except Exception as e:
            if "Cannot update in-progress endpoint" in str(e):
                sleep(retry_delay)
            else:
                raise e
    waiter = client.get_waiter("endpoint_in_service")
    waiter.wait(
        EndpointName=endpoint_name, WaiterConfig={"Delay": 10, "MaxAttempts": 120}
    )


def main(input: dict):
    deploy_action = input["deploy_action"]
    if deploy_action != "apply":
        # Any action other than apply (plan, destroy, ...) is a no-op.
        print(
            json.dumps({"type": "no_change", "endpoint_name": input["endpoint_name"]})
        )
        return
    session = boto3.Session(
        aws_access_key_id=input["aws_access_key_id"],
        aws_secret_access_key=input["aws_secret_access_key"],
    )
    ec_model_name = input["endpoint_config_model_name"]
    ec_model_image = input["endpoint_config_model_image"]
    ec_instance_type = input["endpoint_config_instance_type"]
    ec_output_path = input["endpoint_config_output_path"]
    ec_error_path = input["endpoint_config_error_path"]
    ec_success_topic = input["endpoint_config_success_topic"]
    ec_error_topic = input["endpoint_config_error_topic"]
    endpoint_name = input["endpoint_name"]
    endpoint_alarm_name = input["endpoint_alarm_name"]
    client = session.client("sagemaker", region_name="eu-north-1")
    config_name, image_name = endpoint_exists(client, endpoint_name)
    if config_name:
        if ec_model_image == image_name:
            # The requested image is already deployed; nothing to do.
            print(json.dumps({"type": "no_change", "endpoint_name": endpoint_name}))
        else:
            new_config_name, config_arn = create_endpoint_config(
                client,
                endpoint_name,
                ec_model_name,
                ec_instance_type,
                ec_output_path,
                ec_error_path,
                ec_success_topic,
                ec_error_topic,
            )
            update_endpoint(client, endpoint_name, new_config_name, endpoint_alarm_name)
            # Clean up the config the endpoint was previously running on.
            delete_endpoint_config(client, config_name)
            print(
                json.dumps(
                    {
                        "type": "update",
                        "endpoint_name": endpoint_name,
                        "endpoint_config_name": new_config_name,
                        "model_name": ec_model_name,
                        "old_endpoint_config_name": config_name,
                        "endpoint_config_arn": config_arn,
                    }
                )
            )
    else:
        new_config_name, config_arn = create_endpoint_config(
            client,
            endpoint_name,
            ec_model_name,
            ec_instance_type,
            ec_output_path,
            ec_error_path,
            ec_success_topic,
            ec_error_topic,
        )
        create_endpoint(client, endpoint_name, new_config_name, endpoint_alarm_name)
        print(
            json.dumps(
                {
                    "type": "new",
                    "endpoint_name": endpoint_name,
                    "endpoint_config_name": new_config_name,
                    "endpoint_config_arn": config_arn,
                }
            )
        )


if __name__ == "__main__":
    # Terraform's external data source pipes the query block in as JSON on stdin.
    input = sys.stdin.read()
    input_json = json.loads(input)
    main(input_json)
Conclusion
It is unfortunate that this has been a forum problem for more than four years with no fix. But at least this solution lets you drive SageMaker's logic for replacing the underlying containers correctly and safely, with a controlled rollout and automatic rollback.
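If you want the rollout to be more gradual than the ALL_AT_ONCE policy used in create_endpoint above, the same DeploymentConfig shape accepts a canary traffic-routing configuration. A hedged sketch of what that swap could look like (the percentage and wait times are illustrative, not values we have tuned):

# Hypothetical alternative to the ALL_AT_ONCE policy above: shift 10% of
# capacity first, wait, then flip the rest, rolling back on the same alarm.
deployment_config = {
    "BlueGreenUpdatePolicy": {
        "TrafficRoutingConfiguration": {
            "Type": "CANARY",
            "CanarySize": {"Type": "CAPACITY_PERCENT", "Value": 10},
            "WaitIntervalInSeconds": 300,
        },
        "TerminationWaitInSeconds": 120,
    },
    "AutoRollbackConfiguration": {
        "Alarms": [{"AlarmName": "EndToEndDeploymentHighErrorRateAlarm"}]
    },
}
# Passed as DeploymentConfig=deployment_config to client.create_endpoint(...).
# Since update_endpoint sets RetainDeploymentConfig=True, whatever policy the
# endpoint was created with is reused on every subsequent update.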