DEV Community

Petter Gustafsson
Petter Gustafsson

Posted on

Sagemaker model CI/CD

Intro

On our journey to build cheap, scalable GPU inference at Tonar, we chose SageMaker Asynchronous Inference. We manage our infrastructure with Terraform, and this is a short how-to on getting CI/CD to work when you continuously update the models behind SageMaker endpoints. We will use a custom Whisper model as the reference example.

Expected behaviour...

If you follow the AWS and Terraform documentation, you end up with something like the code below for an autoscaling endpoint.

// SageMaker model for the Whisper container.
// The image comes from the SSM parameter `whisper_digest_id` (presumably the
// ECR image reference/digest - confirm). The same value is exposed to the
// container as CONTAINER_ID so the running code can report which image it is.
// NOTE(review): the model has a fixed name ("whisper"), so a change to the
// container presumably makes Terraform replace it under the same name; this is
// part of the update-chain problem described below.
resource "aws_sagemaker_model" "whisper" {
  execution_role_arn = aws_iam_role.whisper.arn
  name               = "whisper"

  primary_container {
    image          = data.aws_ssm_parameter.whisper_digest_id.value
    model_data_url = "s3://..."
    // Runtime settings for the container; SENTRY_DSN is read from SSM.
    environment = {
      STAGE                 = terraform.workspace
      CONTAINER_ID          = data.aws_ssm_parameter.whisper_digest_id.value
      SENTRY_DSN            = data.aws_ssm_parameter.sentry_dsn.value
      SENTRY_ENVIRONMENT    = terraform.workspace
      INSTANCE_TYPE         = "ml.g4dn.xlarge"
    }
  }
}

// Endpoint configuration for asynchronous inference with the model above.
// NOTE(review): no `name` is set, so the provider presumably generates one;
// confirm how a replacement of this config is propagated to the endpoint.
resource "aws_sagemaker_endpoint_configuration" "whisper" {
  // Single variant. "AllTraffic" must match the variant in the autoscaling
  // target's resource_id. Startup/download timeouts are 1800 s / 1200 s.
  production_variants {
    model_name                                        = aws_sagemaker_model.whisper.name
    variant_name                                      = "AllTraffic"
    initial_instance_count                            = 1
    instance_type                                     = "ml.g4dn.xlarge"
    container_startup_health_check_timeout_in_seconds = 1800
    model_data_download_timeout_in_seconds            = 1200
  }

  // Async results and failures are written to S3; SNS notifications
  // (including the inference response) go to the success/error topics.
  async_inference_config {
    output_config {
      s3_output_path  = "s3://..."
      s3_failure_path = "s3://..."
      notification_config {
        include_inference_response_in = ["SUCCESS_NOTIFICATION_TOPIC", "ERROR_NOTIFICATION_TOPIC"]
        success_topic                 = aws_sns_topic.whisper_success_topic.arn
        error_topic                   = aws_sns_topic.whisper_error_topic.arn
      }
    }
    // One in-flight invocation per instance.
    client_config {
      max_concurrent_invocations_per_instance = 1
    }
  }
}


// Endpoint serving the endpoint configuration above.
// Terraform already orders this after the endpoint configuration (and,
// through it, after the model) because endpoint_config_name references it,
// so no explicit depends_on is needed.
// Rolling updates move at most 50% of capacity per batch, wait 180 s between
// batches, and are bounded by a 900 s execution timeout.
resource "aws_sagemaker_endpoint" "whisper" {
  name                 = "whisper"
  endpoint_config_name = aws_sagemaker_endpoint_configuration.whisper.name
  deployment_config {
    rolling_update_policy {
      maximum_batch_size {
        type  = "CAPACITY_PERCENT"
        value = 50
      }
      maximum_execution_timeout_in_seconds = 900
      wait_interval_in_seconds             = 180
    }
  }
}

// Registers the endpoint's AllTraffic variant with Application Auto Scaling,
// allowing between 0 and 30 instances.
// The dependency on the endpoint (and, transitively, its configuration) is
// implied by the reference in resource_id, so no explicit depends_on is needed.
// NOTE(review): role_arn reuses the SageMaker execution role; confirm that this
// role is intended to be used by Application Auto Scaling.
resource "aws_appautoscaling_target" "sagemaker_target" {
  max_capacity       = 30
  min_capacity       = 0
  resource_id        = "endpoint/${aws_sagemaker_endpoint.whisper.name}/variant/AllTraffic"
  role_arn           = aws_iam_role.whisper.arn
  scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
  service_namespace  = "sagemaker"
}

// Target-tracking policy on the scalable target above: keeps the Average of
// AWS/SageMaker ApproximateBacklogSizePerInstance for this endpoint near 10.
// Cooldowns: 30 s for scale-in, 120 s for scale-out.
resource "aws_appautoscaling_policy" "sagemaker_policy_regular" {
  name               = "whisper-invocations-scaling-policy"
  policy_type        = "TargetTrackingScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace

  target_tracking_scaling_policy_configuration {

    customized_metric_specification {
      metric_name = "ApproximateBacklogSizePerInstance"
      namespace   = "AWS/SageMaker"
      dimensions {
        name  = "EndpointName"
        value = aws_sagemaker_endpoint.whisper.name
      }
      statistic = "Average"
    }
    target_value       = 10
    scale_in_cooldown  = 30
    scale_out_cooldown = 120
  }
}

// Scales from 0 to 1 without waiting for queue to fill up.
// Step-scaling policy invoked by the HasBacklogWithoutCapacity alarm: it adds
// one instance (ChangeInCapacity +1) whenever the alarm metric is at or above
// the alarm threshold (metric_interval_lower_bound = 0). Cooldown is 300 s.
// NOTE(review): presumably needed because the per-instance backlog metric used
// by the target-tracking policy cannot drive scale-out from zero instances.
resource "aws_appautoscaling_policy" "sagemaker_policy_zero_to_one" {
  name               = "whisper-backlog-without-capacity-scaling-policy"
  policy_type        = "StepScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace
  step_scaling_policy_configuration {
    adjustment_type         = "ChangeInCapacity"
    cooldown                = 300
    metric_aggregation_type = "Average"
    step_adjustment {
      metric_interval_lower_bound = 0
      scaling_adjustment          = 1
    }
  }
}

// Alarm that triggers the zero-to-one step-scaling policy.
// It fires when the 60 s Average of HasBacklogWithoutCapacity for this
// endpoint is >= 1 in 2 of 2 evaluation periods.
resource "aws_cloudwatch_metric_alarm" "sagemaker_policy_zero_to_one" {
  alarm_name          = "whisper-backlog-without-capacity-scaling-policy"
  metric_name         = "HasBacklogWithoutCapacity"
  namespace           = "AWS/SageMaker"
  statistic           = "Average"
  evaluation_periods  = 2
  datapoints_to_alarm = 2
  comparison_operator = "GreaterThanOrEqualToThreshold"
  threshold           = 1
  treat_missing_data  = "missing"
  dimensions = {
    EndpointName = aws_sagemaker_endpoint.whisper.name
  }
  period            = 60
  alarm_description = "This metric is used to trigger the scaling policy that scales from 0 to 1 without waiting for queue to fill up"
  alarm_actions     = [aws_appautoscaling_policy.sagemaker_policy_zero_to_one.arn]
}
Enter fullscreen mode Exit fullscreen mode

The problem

This doesn't work because Terraform doesn't understand the chain of updates SageMaker requires (new model → new endpoint configuration → endpoint update). In practice, updating the model results in one of two scenarios:

  • The model and endpoint configuration are updated, but the changes aren't applied to the endpoint, so you have to destroy the endpoint to apply them (not good for production...)
  • The endpoint update fails because it still refers to the previous endpoint configuration (which Terraform has already destroyed), which again leads to deleting the endpoint...

The dirty solution

So if you simply want to update the model's underlying code (like any other container), push it to ECR, and have it roll out to your endpoint, this is the only solution I've come up with so far.

First, we change the Terraform code to:

// SageMaker model for the Whisper container. Its name is passed to
// deploy_model.py (data.external.deploy_model) as endpoint_config_model_name.
// No `name` is set here, so the provider presumably generates a fresh name
// whenever the model is (re)created - confirm against the provider docs.
// NOTE(review): consider `lifecycle { create_before_destroy = true }` so the
// model the live endpoint configuration still references is not deleted
// before its replacement exists (relevant if the endpoint scales out or rolls
// back during a deployment) - verify against your rollback requirements.
resource "aws_sagemaker_model" "whisper" {
  execution_role_arn = aws_iam_role.whisper.arn
  primary_container {
    image          = data.aws_ssm_parameter.whisper_digest_id.value
    model_data_url = "s3://..."
    // Runtime settings for the container; SENTRY_DSN is read from SSM.
    environment = {
      STAGE                 = terraform.workspace
      CONTAINER_ID          = data.aws_ssm_parameter.whisper_digest_id.value
      SENTRY_DSN            = data.aws_ssm_parameter.sentry_dsn.value
      SENTRY_ENVIRONMENT    = terraform.workspace
      INSTANCE_TYPE         = "ml.g4dn.xlarge"
    }
  }
}

// Alarm used by deploy_model.py as the blue/green auto-rollback trigger.
// Invocation4XXErrors is emitted as 1 for a 4xx response and 0 otherwise, so
// its Average over a period is an error rate between 0 and 1. With
// "GreaterThanThreshold" and threshold = 1 the alarm could never fire (the
// rate cannot exceed 1), which silently disabled rollback.
// "GreaterThanOrEqualToThreshold" fires when every request in the evaluated
// periods returned a 4xx. Lower the threshold (e.g. 0.05) to roll back on a
// partial error rate.
// NOTE(review): EndpointName is hard-coded and must match the endpoint_name
// passed to the deploy script. Also confirm that this metric is emitted for
// asynchronous invocations.
resource "aws_cloudwatch_metric_alarm" "sagemaker_endpoint_error_rate" {
  alarm_name          = "EndToEndDeploymentHighErrorRateAlarm"
  alarm_description   = "Monitors the error rate of 4xx errors"
  metric_name         = "Invocation4XXErrors"
  namespace           = "AWS/SageMaker"
  statistic           = "Average"
  evaluation_periods  = 2
  comparison_operator = "GreaterThanOrEqualToThreshold"
  threshold           = 1
  period              = 600
  treat_missing_data  = "notBreaching"
  dimensions = {
    EndpointName = "whisper"
    VariantName  = "AllTraffic"
  }
}

// Runs deploy_model.py, which creates or updates the SageMaker endpoint
// configuration and endpoint with boto3. The query object is sent to the
// script on stdin as JSON, and the script's JSON output becomes `result`.
// endpoint_alarm_name references the alarm resource rather than repeating its
// name as a literal, so Terraform creates the rollback alarm before the
// script creates or updates an endpoint that refers to it.
// NOTE(review): data source arguments, including this query, are recorded in
// the Terraform state in plain text, so ACCESS_KEY/SECRET_KEY end up there.
// Prefer letting boto3 use the ambient credential chain (e.g. the CI runner's
// role) and dropping these two keys once the script tolerates their absence.
data "external" "deploy_model" {
  program = ["python", "${path.module}/deploy_model.py"]

  query = {
    deploy_action                 = var.DEPLOY_ACTION
    aws_access_key_id             = var.ACCESS_KEY
    aws_secret_access_key         = var.SECRET_KEY
    endpoint_name                 = "whisper"
    endpoint_alarm_name           = aws_cloudwatch_metric_alarm.sagemaker_endpoint_error_rate.alarm_name
    endpoint_config_model_name    = aws_sagemaker_model.whisper.name
    endpoint_config_model_image   = data.aws_ssm_parameter.whisper_digest_id.value
    endpoint_config_instance_type = "ml.g4dn.xlarge"
    endpoint_config_output_path   = "s3://..."
    endpoint_config_error_path    = "s3://..."
    endpoint_config_success_topic = aws_sns_topic.whisper_success_topic.arn
    endpoint_config_error_topic   = aws_sns_topic.whisper_error_topic.arn
  }
  depends_on = [aws_sagemaker_model.whisper]
}

// Surfaces the JSON object printed by deploy_model.py (its "type" is
// no_change, new or update, plus endpoint/config names) as a Terraform output.
output "model_deployment" {
  description = "Model deployment"
  value       = data.external.deploy_model.result
}

// Registers the AllTraffic variant of the script-managed endpoint with
// Application Auto Scaling, allowing between 0 and 30 instances.
// resource_id reads the endpoint name from the deploy script's result, which
// already orders this after data.external.deploy_model; no explicit
// depends_on is needed.
// NOTE(review): role_arn reuses the SageMaker execution role; confirm that this
// role is intended to be used by Application Auto Scaling.
resource "aws_appautoscaling_target" "sagemaker_target" {
  max_capacity       = 30
  min_capacity       = 0
  resource_id        = "endpoint/${data.external.deploy_model.result["endpoint_name"]}/variant/AllTraffic"
  role_arn           = aws_iam_role.whisper.arn
  scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
  service_namespace  = "sagemaker"
}

// Target-tracking policy on the scalable target above: keeps the Average of
// AWS/SageMaker ApproximateBacklogSizePerInstance near 10 for the endpoint
// named in the deploy script's result.
// Cooldowns: 30 s for scale-in, 120 s for scale-out.
resource "aws_appautoscaling_policy" "sagemaker_policy_regular" {
  name               = "whisper-invocations-scaling-policy"
  policy_type        = "TargetTrackingScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace

  target_tracking_scaling_policy_configuration {

    customized_metric_specification {
      metric_name = "ApproximateBacklogSizePerInstance"
      namespace   = "AWS/SageMaker"
      dimensions {
        name  = "EndpointName"
        value = data.external.deploy_model.result["endpoint_name"]
      }
      statistic = "Average"
    }
    target_value       = 10
    scale_in_cooldown  = 30
    scale_out_cooldown = 120
  }
}

// Scales from 0 to 1 without waiting for queue to fill up.
// Step-scaling policy invoked by the HasBacklogWithoutCapacity alarm: it adds
// one instance (ChangeInCapacity +1) whenever the alarm metric is at or above
// the alarm threshold (metric_interval_lower_bound = 0). Cooldown is 300 s.
// NOTE(review): presumably needed because the per-instance backlog metric used
// by the target-tracking policy cannot drive scale-out from zero instances.
resource "aws_appautoscaling_policy" "sagemaker_policy_zero_to_one" {
  name               = "whisper-backlog-without-capacity-scaling-policy"
  policy_type        = "StepScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace
  step_scaling_policy_configuration {
    adjustment_type         = "ChangeInCapacity"
    cooldown                = 300
    metric_aggregation_type = "Average"
    step_adjustment {
      metric_interval_lower_bound = 0
      scaling_adjustment          = 1
    }
  }
}

// Alarm that triggers the zero-to-one step-scaling policy.
// It fires when the 60 s Average of HasBacklogWithoutCapacity for the
// script-managed endpoint is >= 1 in 2 of 2 evaluation periods.
resource "aws_cloudwatch_metric_alarm" "sagemaker_policy_zero_to_one" {
  alarm_name          = "whisper-backlog-without-capacity-scaling-policy"
  metric_name         = "HasBacklogWithoutCapacity"
  namespace           = "AWS/SageMaker"
  statistic           = "Average"
  evaluation_periods  = 2
  datapoints_to_alarm = 2
  comparison_operator = "GreaterThanOrEqualToThreshold"
  threshold           = 1
  treat_missing_data  = "missing"
  dimensions = {
    EndpointName = data.external.deploy_model.result["endpoint_name"]
  }
  period            = 60
  alarm_description = "This metric is used to trigger the scaling policy that scales from 0 to 1 without waiting for queue to fill up"
  alarm_actions     = [aws_appautoscaling_policy.sagemaker_policy_zero_to_one.arn]
}
Enter fullscreen mode Exit fullscreen mode

Basically, we move the entire CRUD management of the endpoint and its endpoint configuration into a Python script that uses boto3 (the model itself is still created by Terraform). As you can see below, some values are hard-coded and others are passed into the script as parameters. I also pass in the Terraform action (plan, apply, etc.) because I only want the script to make changes on an apply. Maybe not the best way, but it fits well with our GitHub Actions workflow.

import boto3
import json
import sys
from datetime import datetime
from time import sleep


def endpoint_exists(client, endpoint_name: str) -> tuple[str | None, str | None]:
    try:
        res = client.describe_endpoint(EndpointName=endpoint_name)
        return (
            res["EndpointConfigName"],
            res["ProductionVariants"][0]["DeployedImages"][0]["SpecifiedImage"],
        )
    except Exception:
        return None, None


def create_endpoint_config(
    client,
    name: str,
    model_name: str,
    instance_type: str,
    output_path: str,
    error_path: str,
    success_topic: str,
    error_topic: str,
):
    """Create a new asynchronous-inference endpoint configuration.

    The configuration is named ``name`` plus a ``-YYYYmmddHHMMSS`` suffix from
    the local clock, so each call produces a new configuration instead of
    modifying an existing one.

    It has a single ``AllTraffic`` variant for ``model_name`` that starts with
    one ``instance_type`` instance. Async results go to ``output_path`` and
    failures to ``error_path``. SNS notifications, including the inference
    response, are sent to ``success_topic`` / ``error_topic``. Each instance
    accepts at most one concurrent invocation.

    Returns:
        ``(config_name, endpoint_config_arn)``.
    """
    suffix = datetime.now().strftime("%Y%m%d%H%M%S")
    config_name = f"{name}-{suffix}"

    variant = {
        "ModelName": model_name,
        "VariantName": "AllTraffic",
        "InstanceType": instance_type,
        "InitialInstanceCount": 1,
        "ContainerStartupHealthCheckTimeoutInSeconds": 1800,
        "ModelDataDownloadTimeoutInSeconds": 1200,
    }
    notification_config = {
        "IncludeInferenceResponseIn": [
            "SUCCESS_NOTIFICATION_TOPIC",
            "ERROR_NOTIFICATION_TOPIC",
        ],
        "SuccessTopic": success_topic,
        "ErrorTopic": error_topic,
    }
    output_config = {
        "S3OutputPath": output_path,
        "S3FailurePath": error_path,
        "NotificationConfig": notification_config,
    }

    response = client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[variant],
        AsyncInferenceConfig={
            "OutputConfig": output_config,
            "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 1},
        },
    )
    return config_name, response["EndpointConfigArn"]


def create_endpoint(
    client, endpoint_name: str, endpoint_config_name: str, alarm_name: str
):
    """Create the endpoint and block until SageMaker reports it InService.

    The endpoint is created with a blue/green update policy that shifts all
    traffic at once (no wait interval). Automatic rollback is tied to the
    CloudWatch alarm ``alarm_name``. The wait polls every 10 s for at most 60
    attempts.

    NOTE(review): no baking period (TerminationWaitInSeconds) is configured,
    so the rollback alarm may have little time to fire during an update -
    confirm that this is intended.
    """
    traffic_routing = {"Type": "ALL_AT_ONCE", "WaitIntervalInSeconds": 0}
    rollback = {"Alarms": [{"AlarmName": alarm_name}]}
    deployment = {
        "BlueGreenUpdatePolicy": {"TrafficRoutingConfiguration": traffic_routing},
        "AutoRollbackConfiguration": rollback,
    }
    client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
        DeploymentConfig=deployment,
    )
    client.get_waiter("endpoint_in_service").wait(
        EndpointName=endpoint_name,
        WaiterConfig={"Delay": 10, "MaxAttempts": 60},
    )


def delete_endpoint_config(client, endpoint_config_name: str):
    """Remove the SageMaker endpoint configuration called ``endpoint_config_name``."""
    request = {"EndpointConfigName": endpoint_config_name}
    client.delete_endpoint_config(**request)


def update_endpoint(
    client,
    endpoint_name: str,
    endpoint_config_name: str,
    alarm_name: str,
    max_attempts: int = 60,
    retry_delay: float = 10,
):
    """Switch an existing endpoint to ``endpoint_config_name`` and wait for it.

    The update keeps the endpoint's existing variant properties and deployment
    configuration (``RetainAllVariantProperties`` / ``RetainDeploymentConfig``).
    While SageMaker rejects the call with "Cannot update in-progress endpoint",
    it is retried. After the call is accepted, this blocks until the endpoint
    is InService again (polling every 10 s, at most 120 attempts).

    Args:
        client: boto3 SageMaker client.
        endpoint_name: Endpoint to update.
        endpoint_config_name: Endpoint configuration to switch to.
        alarm_name: Unused. The rollback alarms come from the retained
            deployment configuration. Kept for signature compatibility.
        max_attempts: Maximum number of ``update_endpoint`` calls while the
            endpoint is busy with a previous update.
        retry_delay: Seconds to sleep between those calls.

    Raises:
        RuntimeError: If the endpoint is still busy after ``max_attempts``
            calls. Previously this fell through silently and the caller
            continued as if the update had been applied.
        Exception: Any other error from ``update_endpoint`` or the waiter.
    """
    for _ in range(max_attempts):
        try:
            client.update_endpoint(
                EndpointName=endpoint_name,
                EndpointConfigName=endpoint_config_name,
                RetainAllVariantProperties=True,
                RetainDeploymentConfig=True,
            )
        except Exception as err:
            if "Cannot update in-progress endpoint" not in str(err):
                raise
            # A previous deployment is still running; wait and try again.
            sleep(retry_delay)
        else:
            break
    else:
        raise RuntimeError(
            f"Endpoint {endpoint_name!r} was still updating after "
            f"{max_attempts} attempts; {endpoint_config_name!r} was not applied"
        )
    waiter = client.get_waiter("endpoint_in_service")
    waiter.wait(
        EndpointName=endpoint_name, WaiterConfig={"Delay": 10, "MaxAttempts": 120}
    )


def _create_config_for(client, query: dict) -> tuple[str, str]:
    """Create a new endpoint configuration from the ``endpoint_config_*`` query keys.

    Returns ``(config_name, endpoint_config_arn)`` from ``create_endpoint_config``.
    """
    return create_endpoint_config(
        client,
        query["endpoint_name"],
        query["endpoint_config_model_name"],
        query["endpoint_config_instance_type"],
        query["endpoint_config_output_path"],
        query["endpoint_config_error_path"],
        query["endpoint_config_success_topic"],
        query["endpoint_config_error_topic"],
    )


def _emit_result(payload: dict) -> None:
    """Print ``payload`` as the single JSON object the external data source reads."""
    print(json.dumps(payload))


def main(input: dict):
    """Entry point for the Terraform ``external`` data source.

    ``input`` is the data source's ``query`` object. Unless ``deploy_action``
    is ``"apply"``, no AWS call is made and a ``no_change`` result is printed.
    Otherwise:

    * if the endpoint does not exist, a new endpoint configuration and
      endpoint are created (``type: new``);
    * if the deployed image equals ``endpoint_config_model_image``, nothing
      changes (``type: no_change``);
    * otherwise a new endpoint configuration is created, the endpoint is
      switched to it, and the previous configuration is deleted
      (``type: update``).

    The result is printed to stdout as a flat JSON object of strings.

    Optional query keys:
        aws_access_key_id / aws_secret_access_key: explicit credentials. When
            absent or empty, boto3 falls back to its default credential chain.
        aws_region: SageMaker region, ``"eu-north-1"`` if not given.

    Raises:
        RuntimeError: If, after an update, the endpoint is not on the new
            configuration (e.g. the alarm-based rollback restored the old
            one). The old configuration is kept in that case, and the new,
            unused configuration is left for manual cleanup.
    """
    endpoint_name = input["endpoint_name"]
    if input["deploy_action"] != "apply":
        # plan/validate/...: report only, never touch AWS.
        _emit_result({"type": "no_change", "endpoint_name": endpoint_name})
        return

    session = boto3.Session(
        aws_access_key_id=input.get("aws_access_key_id"),
        aws_secret_access_key=input.get("aws_secret_access_key"),
    )
    client = session.client(
        "sagemaker", region_name=input.get("aws_region", "eu-north-1")
    )
    alarm_name = input["endpoint_alarm_name"]

    old_config_name, deployed_image = endpoint_exists(client, endpoint_name)

    if not old_config_name:
        new_config_name, config_arn = _create_config_for(client, input)
        create_endpoint(client, endpoint_name, new_config_name, alarm_name)
        _emit_result(
            {
                "type": "new",
                "endpoint_name": endpoint_name,
                "endpoint_config_name": new_config_name,
                "endpoint_config_arn": config_arn,
            }
        )
        return

    if deployed_image == input["endpoint_config_model_image"]:
        _emit_result({"type": "no_change", "endpoint_name": endpoint_name})
        return

    new_config_name, config_arn = _create_config_for(client, input)
    update_endpoint(client, endpoint_name, new_config_name, alarm_name)

    # If the deployment was rolled back, the endpoint is still serving the old
    # configuration. Deleting it and reporting "update" would both be wrong.
    current_config_name, _ = endpoint_exists(client, endpoint_name)
    if current_config_name != new_config_name:
        raise RuntimeError(
            f"Endpoint {endpoint_name!r} is on {current_config_name!r} instead of "
            f"{new_config_name!r} after the update (deployment probably rolled "
            f"back); keeping previous configuration {old_config_name!r}"
        )

    delete_endpoint_config(client, old_config_name)
    _emit_result(
        {
            "type": "update",
            "endpoint_name": endpoint_name,
            "endpoint_config_name": new_config_name,
            "model_name": input["endpoint_config_model_name"],
            "old_endpoint_config_name": old_config_name,
            "endpoint_config_arn": config_arn,
        }
    )


if __name__ == "__main__":
    # Terraform's external data source writes the query object to stdin as one
    # JSON document. Parse it straight from the stream rather than binding it to
    # a module-level name `input`, which would shadow the builtin for the whole
    # module.
    main(json.load(sys.stdin))
Enter fullscreen mode Exit fullscreen mode

Conclusion

It's unfortunate that this has been an open problem on the forums for more than four years without a fix. But at least this solution lets you replace the underlying SageMaker containers safely, with a blue/green rollout and alarm-based automatic rollback.

Top comments (0)