DEV Community

Kate Vu

Originally published at Medium

Building a Serverless LLM Pipeline with Amazon Bedrock and SageMaker Fine-Tuning using AWS CDK

Large language models (LLMs) can support a wide range of use cases, such as classification and summarization. However, they often need additional customization to incorporate domain-specific knowledge and up-to-date information.
In this post, we will build serverless pipelines that fine-tune LLMs with Amazon SageMaker and deploy the resulting models. Using AWS CDK for infrastructure as code, the solution separates the training workflow from the inference workflow, so production workloads stay stable and unaffected while models are trained and updated. Additionally, AWS AppConfig allows the runtime configuration to be updated dynamically without redeploying.
The app is built using Kiro 🔥

Architecture Overview

The system is composed of two main pipelines, as shown in the diagram below:

  • Training/Fine-tuning pipeline: responsible for data preparation, model fine-tuning, evaluation, and approval.

  • Inference pipeline: responsible for serving production requests using the approved model.

1. Training pipeline

The training pipeline fine-tunes LLMs on AWS, relying on AWS resources to do the heavy lifting. The workflow is initiated manually.
Data preparation:

  • Training datasets are downloaded from one of three sources: Hugging Face, the Amazon Customer Reviews dataset, or synthetic data.
  • The data is formatted and split into three sets: a training dataset, a validation dataset, and a test dataset.
  • The datasets are then uploaded to an S3 bucket, ready for the training process.

Model fine-tuning:

  • Fine-tuning is executed on Amazon SageMaker and triggered by a Python script.
  • The script supports both full fine-tuning and LoRA, with LoRA as the default.
  • After the training job completes, evaluation metrics are generated. If you are satisfied with the results, register the model in the SageMaker Model Registry and wait for approval.

Automated deployment trigger: once the model is approved, a Lambda function is triggered automatically to:

  • Create a new SageMaker endpoint.
  • Update AWS Systems Manager Parameter Store with the new endpoint name.
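The script's default LoRA option is what keeps fine-tuning cheap. The sketch below is not part of the author's script; it only illustrates the standard LoRA idea. Instead of updating a full d × k weight matrix W, LoRA freezes W and trains a low-rank update, so the adapted weight is W + (α/r)·B·A, with B of shape d × r and A of shape r × k. The hypothetical helper `trainable_params` counts the trainable parameters for one such matrix under each option.

```python
from typing import Optional


def trainable_params(d: int, k: int, rank: Optional[int] = None) -> int:
    """Trainable parameters for a single d x k weight matrix.

    rank=None -> full fine-tuning: every entry of W is updated (d * k).
    rank=r    -> LoRA: W is frozen; only B (d x r) and A (r x k) are trained.
    """
    if rank is None:
        return d * k
    if rank <= 0 or rank > min(d, k):
        raise ValueError("LoRA rank must be between 1 and min(d, k)")
    return rank * (d + k)


if __name__ == "__main__":
    # Illustrative size: a 4096 x 4096 projection matrix adapted with rank 8
    full = trainable_params(4096, 4096)
    lora = trainable_params(4096, 4096, rank=8)
    print(f"full={full:,} lora={lora:,} ratio={lora / full:.2%}")
    # prints: full=16,777,216 lora=65,536 ratio=0.39%
```

For this one matrix, LoRA trains about 0.39% of the parameters that full fine-tuning would. The savings for a whole model depend on which layers are adapted and on the chosen rank.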

2. Inference pipeline

This pipeline handles real-time review summarization requests from users. Incoming requests are received by API Gateway, which invokes a Lambda function to process them. The generated summaries are stored in an S3 bucket for later auditing, monitoring, or analysis.
To enable a comparison between the foundation model and the fine-tuned model, the Lambda function first invokes a foundation model on Amazon Bedrock. It then invokes the SageMaker endpoint created by the training pipeline above.
AWS AppConfig is used to manage runtime settings such as which model to invoke. This approach enables dynamic model switching without redeploying the whole application.
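A minimal sketch of that request flow, assuming the Lambda function already has the review text and the parsed AppConfig document. It is not the author's handler: the prompt, the S3 key layout, and the SageMaker payload shape (`{"inputs": ...}`, common for text-generation containers) are assumptions, and the boto3 clients (bedrock-runtime, sagemaker-runtime, ssm, s3) are injected so the logic can be tested without AWS. The Bedrock call uses the Converse API with the modelId, maxTokens, temperature, and topP values from the runtime config, and the endpoint name comes from the Parameter Store value that the training pipeline updates.

```python
import json
import uuid
from datetime import datetime, timezone
from typing import Any, Dict


def build_prompt(review_text: str) -> str:
    # Hypothetical prompt: the post does not show the one the author uses
    return f"Summarize this customer review in one or two sentences:\n\n{review_text}"


def summarize_review(review_text: str, config: Dict[str, Any], clients: Dict[str, Any],
                     parameter_name: str, bucket: str) -> Dict[str, Any]:
    """Summarize with the Bedrock foundation model, then with the fine-tuned
    SageMaker endpoint (if enabled), and store both results in S3."""
    prompt = build_prompt(review_text)
    result: Dict[str, Any] = {"review": review_text}

    # 1. Foundation model through the Bedrock Converse API, tuned by AppConfig
    bedrock_cfg = config["bedrock"]
    fm = clients["bedrock"].converse(
        modelId=bedrock_cfg["modelId"],
        messages=[{"role": "user", "content": [{"text": prompt}]}],
        inferenceConfig={
            "maxTokens": bedrock_cfg["maxTokens"],
            "temperature": bedrock_cfg["temperature"],
            "topP": bedrock_cfg["topP"],
        },
    )
    result["foundationSummary"] = fm["output"]["message"]["content"][0]["text"]

    # 2. Fine-tuned model, only when enabled in AppConfig
    if config.get("sagemaker", {}).get("enabled", False):
        param = clients["ssm"].get_parameter(Name=parameter_name)
        endpoint = param["Parameter"]["Value"]
        if endpoint != "none":  # "none" is the placeholder the training stack creates
            response = clients["sagemaker"].invoke_endpoint(
                EndpointName=endpoint,
                ContentType="application/json",
                Body=json.dumps({"inputs": prompt}),  # assumed payload shape
            )
            result["endpoint"] = endpoint
            result["fineTunedSummary"] = json.loads(response["Body"].read().decode("utf-8"))

    # 3. Keep both outputs for auditing, monitoring, and analysis
    key = f"summaries/{datetime.now(timezone.utc):%Y/%m/%d}/{uuid.uuid4()}.json"
    clients["s3"].put_object(Bucket=bucket, Key=key, Body=json.dumps(result),
                             ContentType="application/json")
    result["s3Key"] = key
    return result
```

In a real handler, the clients would typically be created once at module level (for example `boto3.client("bedrock-runtime")`) so warm invocations reuse them.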


AWS Resources:

  • Amazon SageMaker
  • Amazon S3 buckets
  • Amazon API Gateway
  • AWS Lambda
  • AWS AppConfig
  • AWS Systems Manager Parameter Store
  • Amazon EventBridge
  • Amazon Bedrock
  • AWS Identity and Access Management (IAM)
  • Amazon CloudWatch

Prerequisites:

An AWS account that has been bootstrapped for AWS CDK.
Environment setup:

  • Node.js
  • TypeScript
  • AWS CDK Toolkit
  • Docker (used to bundle the Lambda functions during deployment)

AWS credentials: keep them handy so you can deploy the stacks.

Building the app

1. AppConfig stack

This stack uses AWS AppConfig to store the runtime configuration. First, we define a JSON file for each environment, named config/appconfig-<environment>.json (for example, config/appconfig-kate.json):

{
  "bedrock": {
    "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
    "maxTokens": 200,
    "temperature": 0.5,
    "topP": 0.9
  },
  "sagemaker": {
    "enabled": true,
    "timeout": 30000,
    "models": {
      "stable": {
        "endpointName": "endpoint-kate",
        "description": "Kate's development model",
        "weight": 100
      }
    },
    "strategy": "weighted"
  },
  "rag": {
    "enabled": false,
    "topK": 3
  },
  "features": {
    "sentimentAnalysis": true,
    "caching": false,
    "useNewSummarizationPrompt": false,
    "enableAdvancedRAG": false,
    "useMultiModelEnsemble": false
  },
  "abTesting": {
    "enabled": false,
    "rules": []
  },
  "monitoring": {
    "logABTestAssignments": true,
    "trackModelPerformance": true,
    "metricsNamespace": "LLMPipeline/Kate"
  }
}
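The JSON declares each SageMaker model with a weight and sets "strategy": "weighted", but how the Lambda function interprets those fields is not shown. A minimal sketch, assuming "weighted" means each model is picked with probability weight / sum of weights; `choose_endpoint` is a hypothetical helper name.

```python
import random
from typing import Any, Dict, Optional


def choose_endpoint(sagemaker_cfg: Dict[str, Any],
                    rng: Optional[random.Random] = None) -> Optional[str]:
    """Pick an endpoint name from the "sagemaker" section of the runtime config.

    Returns None when SageMaker is disabled or no model has a positive weight.
    """
    if not sagemaker_cfg.get("enabled", False):
        return None
    models = [m for m in sagemaker_cfg.get("models", {}).values()
              if m.get("weight", 0) > 0]
    if not models:
        return None
    strategy = sagemaker_cfg.get("strategy", "weighted")
    if strategy != "weighted":
        raise ValueError(f"Unsupported strategy: {strategy}")
    rng = rng or random.Random()
    # Proportional random choice: a model with weight 90 is picked ~9x as often as weight 10
    chosen = rng.choices(models, weights=[m["weight"] for m in models], k=1)[0]
    return chosen["endpointName"]
```

With only the stable model at weight 100, every request goes to endpoint-kate. Adding a second entry, for example a hypothetical canary model at weight 10 next to stable at 90, would route roughly 10% of requests to it, and because the weights live in AppConfig that split can change without a redeploy.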

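Then we create the stack. It reads the JSON file for the target environment at synth time, validates it, and deploys it as a hosted configuration version: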

import * as cdk from 'aws-cdk-lib';
import { Construct } from 'constructs';
import * as appconfig from 'aws-cdk-lib/aws-appconfig';
import * as iam from 'aws-cdk-lib/aws-iam';
import * as fs from 'fs';
import * as path from 'path';
import { EnvironmentConfig } from './utils';

export interface AppConfigStackProps extends cdk.StackProps {
  config: EnvironmentConfig;
}

export class AppConfigStack extends cdk.Stack {
  public readonly application: appconfig.CfnApplication;
  public readonly appConfigEnvironment: appconfig.CfnEnvironment;
  public readonly configurationProfile: appconfig.CfnConfigurationProfile;

  constructor(scope: Construct, id: string, props: AppConfigStackProps) {
    super(scope, id, props);

    const { config } = props;

    // Create AppConfig Application
    this.application = new appconfig.CfnApplication(this, 'Application', {
      name: `llm-pipeline-${config.environmentName}`,
      description: 'Configuration for LLM Pipeline',
    });

    // Create AppConfig Environment
    this.appConfigEnvironment = new appconfig.CfnEnvironment(this, 'Environment', {
      applicationId: this.application.ref,
      name: config.environmentName,
      description: `${config.environmentName} environment`,
    });

    // Create Configuration Profile
    this.configurationProfile = new appconfig.CfnConfigurationProfile(this, 'ConfigProfile', {
      applicationId: this.application.ref,
      name: 'runtime-config',
      description: 'Runtime configuration for Lambda functions',
      locationUri: 'hosted',
      type: 'AWS.Freeform',
    });

    // Initial configuration with A/B testing support
    // These are RUNTIME settings that can be updated without redeployment
    // Loaded from config/appconfig-{environment}.json
    const configPath = path.join(__dirname, `../config/appconfig-${config.environmentName}.json`);

    let configContent: string;
    if (!fs.existsSync(configPath)) {
      throw new Error(
        `\n========================================\n` +
        `ERROR: AppConfig file missing for environment "${config.environmentName}"\n` +
        `========================================\n` +
        `Expected file: config/appconfig-${config.environmentName}.json\n` +
        `Full path: ${configPath}\n\n` +
        `Please create this file with runtime configuration.\n` +
        `You can copy from an existing environment:\n` +
        `  cp config/appconfig-kate.json config/appconfig-${config.environmentName}.json\n` +
        `========================================\n`
      );
    }

    try {
      configContent = fs.readFileSync(configPath, 'utf8');
      // Validate it's valid JSON
      JSON.parse(configContent);
      console.log(`✓ Loaded AppConfig for "${config.environmentName}" from: ${configPath}`);
    } catch (error) {
      throw new Error(
        `\n========================================\n` +
        `ERROR: Invalid AppConfig JSON for environment "${config.environmentName}"\n` +
        `========================================\n` +
        `File: config/appconfig-${config.environmentName}.json\n` +
        `Error: ${error instanceof Error ? error.message : String(error)}\n\n` +
        `Please ensure the file contains valid JSON.\n` +
        `Check for:\n` +
        `  - Missing commas\n` +
        `  - Trailing commas\n` +
        `  - Unquoted keys\n` +
        `  - Invalid escape sequences\n` +
        `========================================\n`
      );
    }

    // Create deployment strategy (immediate deployment)
    const deploymentStrategy = new appconfig.CfnDeploymentStrategy(this, 'DeploymentStrategy', {
      name: `immediate-${config.environmentName}`,
      deploymentDurationInMinutes: 0,
      growthFactor: 100,
      replicateTo: 'NONE',
      finalBakeTimeInMinutes: 0,
    });

    // Create hosted configuration version
    const configVersion = new appconfig.CfnHostedConfigurationVersion(this, 'ConfigVersion', {
      applicationId: this.application.ref,
      configurationProfileId: this.configurationProfile.ref,
      content: configContent,
      contentType: 'application/json',
      description: 'Initial configuration',
    });

    // Automatically deploy the configuration
    new appconfig.CfnDeployment(this, 'Deployment', {
      applicationId: this.application.ref,
      environmentId: this.appConfigEnvironment.ref,
      deploymentStrategyId: deploymentStrategy.ref,
      configurationProfileId: this.configurationProfile.ref,
      configurationVersion: configVersion.ref,
      description: 'Automatic deployment from CDK',
    });

    // Outputs
    new cdk.CfnOutput(this, 'ApplicationId', {
      value: this.application.ref,
      description: 'AppConfig Application ID',
      exportName: `${config.environmentName}-appconfig-app-id`,
    });

    new cdk.CfnOutput(this, 'EnvironmentId', {
      value: this.appConfigEnvironment.ref,
      description: 'AppConfig Environment ID',
      exportName: `${config.environmentName}-appconfig-env-id`,
    });

    new cdk.CfnOutput(this, 'ConfigurationProfileId', {
      value: this.configurationProfile.ref,
      description: 'AppConfig Configuration Profile ID',
      exportName: `${config.environmentName}-appconfig-profile-id`,
    });
  }

  /**
   * Grant Lambda function permission to read AppConfig
   */
  public grantRead(grantee: iam.IGrantable): void {
    grantee.grantPrincipal.addToPrincipalPolicy(
      new iam.PolicyStatement({
        effect: iam.Effect.ALLOW,
        actions: [
          'appconfig:GetConfiguration',
          'appconfig:GetLatestConfiguration',
          'appconfig:StartConfigurationSession',
        ],
        resources: ['*'],
      })
    );
  }
}
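The grantRead helper allows appconfig:StartConfigurationSession and appconfig:GetLatestConfiguration, the two operations of the AppConfig Data API. A minimal sketch of how a Lambda function could use them, assuming the application, environment, and profile IDs from the stack outputs are passed in (for example through environment variables); `RuntimeConfig` is a hypothetical class, and the client is injected so it can be tested without AWS. AWS also provides an AppConfig Lambda extension that handles this polling and caching for you.

```python
import json
import time
from typing import Any, Dict, Optional


class RuntimeConfig:
    """Caches the AppConfig document and re-polls only when AppConfig allows it."""

    def __init__(self, client: Any, application_id: str, environment_id: str,
                 profile_id: str) -> None:
        self._client = client  # expected: boto3.client("appconfigdata")
        self._ids = (application_id, environment_id, profile_id)
        self._token: Optional[str] = None
        self._next_poll = 0.0
        self._config: Dict[str, Any] = {}

    def get(self) -> Dict[str, Any]:
        now = time.monotonic()
        if self._token is None:
            # First call: open a session to obtain the initial token
            app, env, profile = self._ids
            session = self._client.start_configuration_session(
                ApplicationIdentifier=app,
                EnvironmentIdentifier=env,
                ConfigurationProfileIdentifier=profile,
            )
            self._token = session["InitialConfigurationToken"]
        elif now < self._next_poll:
            return self._config  # too early to poll again; serve the cached copy

        response = self._client.get_latest_configuration(ConfigurationToken=self._token)
        # Every poll returns a fresh token that must be used for the next call
        self._token = response["NextPollConfigurationToken"]
        self._next_poll = now + response.get("NextPollIntervalInSeconds", 60)

        body = response["Configuration"].read()
        if body:  # empty body: configuration unchanged since the last poll
            self._config = json.loads(body)
        return self._config
```

Creating one instance at module level lets warm invocations reuse the session token. GetLatestConfiguration returns an empty body when nothing has changed since the previous poll, which is why the cached document is kept.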

2. Fine-Tuning Model Pipeline

2.1 Create the pipeline

This stack creates the following AWS resources:

  • S3 buckets: a training data bucket and a model artifacts bucket
  • An IAM role that SageMaker uses for training jobs
  • An SSM Parameter Store parameter that stores the active endpoint name
  • An EventBridge rule that triggers deployment when a model is approved
  • A Lambda function that deploys approved models
  • A CloudWatch log group for that Lambda function

import * as cdk from 'aws-cdk-lib';
import { Construct } from 'constructs';
import * as s3 from 'aws-cdk-lib/aws-s3';
import * as lambda from 'aws-cdk-lib/aws-lambda';
import { PythonFunction } from '@aws-cdk/aws-lambda-python-alpha';
import * as iam from 'aws-cdk-lib/aws-iam';
import * as events from 'aws-cdk-lib/aws-events';
import * as targets from 'aws-cdk-lib/aws-events-targets';
import * as ssm from 'aws-cdk-lib/aws-ssm';
import * as logs from 'aws-cdk-lib/aws-logs';
import { EnvironmentConfig } from './utils';

export interface TrainingPipelineStackProps extends cdk.StackProps {
  config: EnvironmentConfig;
}

export class TrainingPipelineStack extends cdk.Stack {
  public readonly trainingBucket: s3.Bucket;
  public readonly modelBucket: s3.Bucket;
  public readonly endpointParameter: ssm.StringParameter;

  constructor(scope: Construct, id: string, props: TrainingPipelineStackProps) {
    super(scope, id, props);

    const { config } = props;

    // ========================================
    // S3 Buckets for Training
    // ========================================

    // NOTE: Using DESTROY for cost-saving during development
    // For production, change to RETAIN to preserve training data and models
    this.trainingBucket = new s3.Bucket(this, 'TrainingDataBucket', {
      bucketName: `training-data-${config.environmentName}-${cdk.Aws.ACCOUNT_ID}`,
      removalPolicy: cdk.RemovalPolicy.DESTROY,
      autoDeleteObjects: true,
      versioned: true,
      encryption: s3.BucketEncryption.S3_MANAGED,
      lifecycleRules: [
        {
          id: 'DeleteOldVersions',
          noncurrentVersionExpiration: cdk.Duration.days(90),
        },
      ],
    });

    // NOTE: Using DESTROY for cost-saving during development
    // For production, change to RETAIN to preserve model artifacts
    this.modelBucket = new s3.Bucket(this, 'ModelArtifactsBucket', {
      bucketName: `model-artifacts-${config.environmentName}-${cdk.Aws.ACCOUNT_ID}`,
      removalPolicy: cdk.RemovalPolicy.DESTROY,
      autoDeleteObjects: true,
      versioned: true,
      encryption: s3.BucketEncryption.S3_MANAGED,
    });

    // ========================================
    // Parameter Store for Active Endpoint
    // ========================================

    this.endpointParameter = new ssm.StringParameter(this, 'ActiveEndpointParameter', {
      parameterName: `/summarizer/${config.environmentName}/active-endpoint`,
      stringValue: 'none',
      description: 'Active SageMaker endpoint name for inference',
      tier: ssm.ParameterTier.STANDARD,
    });

    // ========================================
    // IAM Role for SageMaker Training
    // ========================================

    const sagemakerRole = new iam.Role(this, 'SageMakerTrainingRole', {
      assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
      managedPolicies: [
        iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'),
      ],
    });

    this.trainingBucket.grantReadWrite(sagemakerRole);
    this.modelBucket.grantReadWrite(sagemakerRole);

    // ========================================
    // Lambda: Update Endpoint on Model Approval
    // ========================================

    const updateEndpointLogGroup = new logs.LogGroup(this, 'UpdateEndpointLogGroup', {
      logGroupName: `/aws/lambda/update-endpoint-${config.environmentName}`,
      retention: logs.RetentionDays.ONE_WEEK,
      removalPolicy: cdk.RemovalPolicy.DESTROY,
    });

    const updateEndpointFn = new PythonFunction(this, 'UpdateEndpointFunction', {
      functionName: `update-endpoint-${config.environmentName}`,
      entry: 'src/lambdas/update-endpoint',
      runtime: lambda.Runtime.PYTHON_3_11,
      index: 'handler.py',
      handler: 'handler',
      description: `Update SageMaker endpoint for ${config.environmentName}`,
      timeout: cdk.Duration.minutes(5),
      memorySize: 256,
      environment: {
        PARAMETER_NAME: this.endpointParameter.parameterName,
        ENVIRONMENT: config.environmentName,
        SAGEMAKER_ROLE_ARN: sagemakerRole.roleArn,
      },
      logGroup: updateEndpointLogGroup,
    });

    // Grant permissions
    this.endpointParameter.grantRead(updateEndpointFn);
    this.endpointParameter.grantWrite(updateEndpointFn);

    updateEndpointFn.addToRolePolicy(new iam.PolicyStatement({
      effect: iam.Effect.ALLOW,
      actions: [
        'sagemaker:DescribeModelPackage',
        'sagemaker:CreateModel',
        'sagemaker:CreateEndpoint',
        'sagemaker:CreateEndpointConfig',
        'sagemaker:UpdateEndpoint',
        'sagemaker:DescribeEndpoint',
      ],
      resources: ['*'],
    }));

    // Grant permission to pass the SageMaker execution role
    updateEndpointFn.addToRolePolicy(new iam.PolicyStatement({
      effect: iam.Effect.ALLOW,
      actions: ['iam:PassRole'],
      resources: [sagemakerRole.roleArn],
    }));

    // ========================================
    // EventBridge: Trigger on Model Approval
    // ========================================

    const modelApprovalRule = new events.Rule(this, 'ModelApprovalRule', {
      ruleName: `model-approval-${config.environmentName}`,
      description: 'Trigger endpoint update when SageMaker model is approved',
      eventPattern: {
        source: ['aws.sagemaker'],
        detailType: ['SageMaker Model Package State Change'],
        detail: {
          ModelApprovalStatus: ['Approved'],
        },
      },
    });

    modelApprovalRule.addTarget(new targets.LambdaFunction(updateEndpointFn));

    // ========================================
    // Outputs
    // ========================================

    new cdk.CfnOutput(this, 'TrainingBucketName', {
      value: this.trainingBucket.bucketName,
      description: 'S3 bucket for training data',
      exportName: `${config.environmentName}-training-bucket`,
    });

    new cdk.CfnOutput(this, 'ModelBucketName', {
      value: this.modelBucket.bucketName,
      description: 'S3 bucket for model artifacts',
      exportName: `${config.environmentName}-model-bucket`,
    });

    new cdk.CfnOutput(this, 'EndpointParameterName', {
      value: this.endpointParameter.parameterName,
      description: 'Parameter Store key for active endpoint',
      exportName: `${config.environmentName}-endpoint-parameter`,
    });

    new cdk.CfnOutput(this, 'SageMakerRoleArn', {
      value: sagemakerRole.roleArn,
      description: 'IAM role for SageMaker training jobs',
      exportName: `${config.environmentName}-sagemaker-role`,
    });
  }
}
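The stack wires the update-endpoint function to the approval rule and passes it PARAMETER_NAME, ENVIRONMENT, and SAGEMAKER_ROLE_ARN, but the handler itself (src/lambdas/update-endpoint/handler.py) is not shown. A minimal sketch of what it could do with the granted permissions, assuming the EventBridge event's detail carries ModelPackageArn; the naming scheme, the ml.g5.xlarge instance type, and the single-variant endpoint config are illustrative choices, not the author's.

```python
import os
import time
from typing import Any, Dict


def deploy_approved_model(event: Dict[str, Any], sm: Any, ssm: Any, role_arn: str,
                          parameter_name: str, environment: str,
                          instance_type: str = "ml.g5.xlarge") -> str:
    """Create a model, endpoint config, and endpoint for an approved model package,
    then publish the endpoint name to Parameter Store. Returns the endpoint name."""
    detail = event["detail"]
    if detail.get("ModelApprovalStatus") != "Approved":
        raise ValueError("Event is not a model approval")
    package_arn = detail["ModelPackageArn"]

    # Timestamp suffix keeps names unique; SageMaker names allow letters, digits, hyphens
    base = f"summarizer-{environment}-{time.strftime('%Y%m%d-%H%M%S', time.gmtime())}"
    config_name = f"{base}-config"

    sm.create_model(
        ModelName=base,
        ExecutionRoleArn=role_arn,
        PrimaryContainer={"ModelPackageName": package_arn},
    )
    sm.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[{
            "VariantName": "AllTraffic",
            "ModelName": base,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        }],
    )
    sm.create_endpoint(EndpointName=base, EndpointConfigName=config_name)

    # The parameter already exists (created by CDK), so overwrite its value
    ssm.put_parameter(Name=parameter_name, Value=base, Overwrite=True)
    return base


def handler(event, context):
    import boto3  # bundled with the Lambda Python runtime

    endpoint = deploy_approved_model(
        event,
        sm=boto3.client("sagemaker"),
        ssm=boto3.client("ssm"),
        role_arn=os.environ["SAGEMAKER_ROLE_ARN"],
        parameter_name=os.environ["PARAMETER_NAME"],
        environment=os.environ["ENVIRONMENT"],
    )
    return {"endpointName": endpoint}
```

This sketch writes the endpoint name to Parameter Store right after CreateEndpoint, while the endpoint is still being created. Because endpoint creation can outlast the function's 5-minute timeout, a more robust version might only update the parameter once DescribeEndpoint reports InService.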

2.2 Create the scripts

Prepare training datasets
Create a Python script to download the datasets. The data can be downloaded from Hugging Face or the Amazon Customer Reviews dataset, or generated synthetically:

#!/usr/bin/env python3
"""
Download and prepare training data from public datasets

This script downloads customer review data and formats it for SageMaker training.
It supports multiple sources:
1. Hugging Face Datasets (recommended - easy and reliable)
2. Amazon Customer Reviews (real data from AWS Open Data Registry)
3. Synthetic data (generated for testing)

Output: training_data/ folder with train.jsonl, validation.jsonl, test.jsonl

Usage:
    # Download from Hugging Face (recommended)
    python scripts/download_training_data.py --source huggingface --dataset amazon_polarity --num-samples 5000

    # Generate synthetic data for testing
    python scripts/download_training_data.py --source synthetic --num-samples 1000

    # Download real Amazon reviews
    python scripts/download_training_data.py --source amazon --max-samples 5000
"""

import os
import json
import gzip
import argparse
import urllib.request
import ssl
from pathlib import Path
from typing import List, Dict
import random

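# Work around SSL certificate verification errors on macOS.
# WARNING: this disables TLS certificate verification for every HTTPS request the
# script makes. Prefer installing the CA certificates for your Python install
# (python.org builds ship an "Install Certificates.command" script) and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context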


def download_huggingface_dataset(
    output_dir: Path, dataset_name: str = "amazon_polarity", max_samples: int = 5000
):
    """
    Download dataset from Hugging Face
    Source: https://huggingface.co/datasets

    Popular datasets:
    - amazon_polarity: Amazon reviews (positive/negative) - NO SUMMARIES
    - yelp_review_full: Yelp reviews with 1-5 star ratings - NO SUMMARIES
    - imdb: Movie reviews - NO SUMMARIES
    - rotten_tomatoes: Movie reviews - NO SUMMARIES
    - app_reviews: Mobile app reviews - NO SUMMARIES
    - cnn_dailymail: News articles WITH SUMMARIES (recommended for summarization)
    - xsum: News WITH SUMMARIES (extreme summarization)
    - samsum: Dialogues WITH SUMMARIES
    """
    print(f"\n📦 Downloading from Hugging Face: {dataset_name}")
    print(f"This may take a few minutes...")

    try:
        from datasets import load_dataset
    except ImportError:
        print("\n❌ Error: 'datasets' library not installed")
        print("Install it with: pip install datasets")
        return []

    try:
        # Load dataset with config if needed
        print(f"Loading dataset '{dataset_name}'...")

        # cnn_dailymail needs a config version; the other datasets load with defaults
        if dataset_name == 'cnn_dailymail':
            dataset = load_dataset(dataset_name, '3.0.0')
        else:
            dataset = load_dataset(dataset_name)

        # Get train split
        train_data = dataset["train"]

        # Process samples
        reviews = []
        count = 0

        print(f"Processing samples...")
        for item in train_data:
            if count >= max_samples:
                break

            # Handle summarization datasets differently
            if dataset_name == 'cnn_dailymail':
                text = item.get('article', '')
                summary = item.get('highlights', '')
                sentiment = 'neutral'
            elif dataset_name == 'xsum':
                text = item.get('document', '')
                summary = item.get('summary', '')
                sentiment = 'neutral'
            elif dataset_name == 'samsum':
                text = item.get('dialogue', '')
                summary = item.get('summary', '')
                sentiment = 'neutral'
            else:
                # Review datasets - extract text and label
                text = None
                label = None

                # Try common field names
                if "content" in item:
                    text = item["content"]
                elif "text" in item:
                    text = item["text"]
                elif "review" in item:
                    text = item["review"]

                if "label" in item:
                    label = item["label"]
                elif "sentiment" in item:
                    label = item["sentiment"]
                elif "stars" in item:
                    label = item["stars"]

                if not text:
                    continue

                # Skip very short reviews
                if len(text) < 50:
                    continue

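                # Determine sentiment from label. Label schemes differ per dataset:
                # amazon_polarity, imdb and rotten_tomatoes use 0 = negative, 1 = positive;
                # yelp_review_full uses 0-4 for 1-5 stars; other int labels are assumed to be 1-5 stars.
                sentiment = "neutral"
                if isinstance(label, int):
                    if dataset_name in ("amazon_polarity", "imdb", "rotten_tomatoes"):
                        sentiment = "positive" if label == 1 else "negative"
                    else:
                        stars = label + 1 if dataset_name == "yelp_review_full" else label
                        if stars >= 4:
                            sentiment = "positive"
                        elif stars <= 2:
                            sentiment = "negative"
                elif isinstance(label, str):
                    sentiment = label.lower()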

                # Create summary (first 150 chars or extract key points)
                # NOTE: This is NOT a real summary, just for demo purposes
                summary = create_summary_from_text(text)

            # Skip if no text or summary
            if not text or not summary or len(text) < 50:
                continue

            reviews.append(
                {
                    "text": text,
                    "summary": summary,
                    "sentiment": sentiment,
                    "source": dataset_name,
                }
            )

            count += 1
            if count % 500 == 0:
                print(f"Processed {count} samples...")

        print(f"✅ Processed {len(reviews)} samples from Hugging Face dataset")
        return reviews

    except Exception as e:
        print(f"\n⚠️  Error loading dataset: {str(e)}")
        print(f"\nAvailable datasets:")
        print("  Summarization (recommended):")
        print("    - cnn_dailymail (news articles with summaries)")
        print("    - xsum (news with one-sentence summaries)")
        print("    - samsum (dialogues with summaries)")
        print("  Reviews (no real summaries):")
        print("    - amazon_polarity")
        print("    - yelp_review_full")
        print("    - imdb")
        print("    - rotten_tomatoes")
        print("    - app_reviews")
        print(
            "\nTry: python scripts/download_training_data.py --source huggingface --dataset cnn_dailymail"
        )
        return []


def create_summary_from_text(text: str, max_length: int = 150) -> str:
    """
    Create a simple summary from review text
    Takes first sentence or first N characters
    """
    # Try to get first sentence
    sentences = text.split(".")
    if sentences and len(sentences[0]) > 20:
        summary = sentences[0].strip() + "."
        if len(summary) <= max_length:
            return summary

    # Otherwise, take first N characters
    if len(text) <= max_length:
        return text

    return text[:max_length].rsplit(" ", 1)[0] + "..."


def download_file(url: str, output_path: str):
    """Download file from URL with progress"""
    print(f"Downloading from {url}...")

    def progress_hook(count, block_size, total_size):
        # total_size is -1 when the server does not report a Content-Length
        if total_size > 0:
            percent = min(100, int(count * block_size * 100 / total_size))
            print(f"\rProgress: {percent}%", end="", flush=True)

    urllib.request.urlretrieve(url, output_path, progress_hook)
    print("\nDownload complete!")


def download_amazon_reviews(
    output_dir: Path, category: str = "Electronics", max_samples: int = 10000
):
    """
    Download Amazon Customer Reviews dataset
    Source: https://registry.opendata.aws/amazon-reviews/
    """
    print(f"\n📦 Downloading Amazon Reviews - {category} category")
    print(f"This may take a few minutes...")

    # Amazon Reviews Open Data URLs
    base_url = "https://s3.amazonaws.com/amazon-reviews-pds/tsv"
    filename = f"amazon_reviews_us_{category}_v1_00.tsv.gz"
    url = f"{base_url}/{filename}"

    # Download
    temp_file = output_dir / filename

    try:
        download_file(url, str(temp_file))
    except Exception as e:
        print(f"\n⚠️  Download failed: {str(e)}")
        print(f"\nTrying alternative method using AWS CLI...")

        # Try using AWS CLI as fallback
        import subprocess

        try:
            result = subprocess.run(
                [
                    "aws",
                    "s3",
                    "cp",
                    f"s3://amazon-reviews-pds/tsv/{filename}",
                    str(temp_file),
                ],
                capture_output=True,
                text=True,
            )
            if result.returncode != 0:
                print(f"AWS CLI also failed: {result.stderr}")
                print(f"\n💡 Tip: You can manually download from:")
                print(f"   {url}")
                print(f"   Save to: {temp_file}")
                return []
        except FileNotFoundError:
            print(f"AWS CLI not found. Please install it or download manually from:")
            print(f"   {url}")
            return []

    # Parse and convert to JSONL
    print(f"\nProcessing reviews...")
    reviews = []

    with gzip.open(temp_file, "rt", encoding="utf-8") as f:
        # Skip header
        header = f.readline().strip().split("\t")

        # Find column indices
        try:
            review_idx = header.index("review_body")
            headline_idx = header.index("review_headline")
            rating_idx = header.index("star_rating")
        except ValueError as e:
            print(f"Error: Could not find required columns in dataset")
            return []

        count = 0
        for line in f:
            if count >= max_samples:
                break

            try:
                fields = line.strip().split("\t")
                if len(fields) <= max(review_idx, headline_idx, rating_idx):
                    continue

                review_text = fields[review_idx]
                headline = fields[headline_idx]
                rating = int(fields[rating_idx])

                # Skip empty reviews
                if not review_text or len(review_text) < 50:
                    continue

                # Determine sentiment from rating
                if rating >= 4:
                    sentiment = "positive"
                elif rating <= 2:
                    sentiment = "negative"
                else:
                    sentiment = "neutral"

                # Use headline as summary (not perfect but works for training)
                # In production, you'd want human-written summaries
                summary = headline if headline else review_text[:100]

                reviews.append(
                    {
                        "text": review_text,
                        "summary": summary,
                        "sentiment": sentiment,
                        "rating": rating,
                    }
                )

                count += 1
                if count % 1000 == 0:
                    print(f"Processed {count} reviews...")

            except Exception as e:
                continue

    # Clean up temp file
    temp_file.unlink()

    print(f"✅ Processed {len(reviews)} reviews from Amazon dataset")
    return reviews


def create_synthetic_data(num_samples: int = 1000) -> List[Dict]:
    """
    Create synthetic training data for testing
    Use this if you can't download real data
    """
    print(f"\n🔧 Generating {num_samples} synthetic reviews...")

    templates = {
        "positive": [
            (
                "This product is absolutely amazing! {feature1} and {feature2}. Highly recommend to anyone looking for quality.",
                "Excellent product with great {feature1} and {feature2}. Highly recommended.",
            ),
            (
                "I'm very impressed with this purchase. The {feature1} exceeded my expectations and {feature2}. Worth every penny!",
                "Very satisfied with {feature1} and {feature2}. Great value.",
            ),
            (
                "Outstanding quality! {feature1} is incredible and {feature2}. Best purchase I've made this year.",
                "Outstanding {feature1} and {feature2}. Excellent purchase.",
            ),
        ],
        "negative": [
            (
                "Very disappointed with this product. {issue1} and {issue2}. Would not recommend.",
                "Poor quality with {issue1} and {issue2}. Not recommended.",
            ),
            (
                "This is a waste of money. {issue1} after just a few days and {issue2}. Terrible experience.",
                "Product failed quickly with {issue1} and {issue2}. Waste of money.",
            ),
            (
                "Do not buy this! {issue1} and {issue2}. Customer service was unhelpful too.",
                "Major issues with {issue1} and {issue2}. Poor support.",
            ),
        ],
        "neutral": [
            (
                "It's okay for the price. {aspect1} but {aspect2}. Nothing special.",
                "Average product. {aspect1} but {aspect2}.",
            ),
            (
                "Does what it's supposed to do. {aspect1} though {aspect2}. Acceptable.",
                "Functional product. {aspect1} with {aspect2}.",
            ),
            (
                "Mixed feelings about this. {aspect1} but {aspect2}. Could be better.",
                "Mixed quality. {aspect1} but {aspect2}.",
            ),
        ],
    }

    features = [
        "The battery life is excellent",
        "The build quality feels premium",
        "The performance is outstanding",
        "The design is beautiful",
        "The screen quality is amazing",
        "The sound quality is superb",
        "The camera takes great photos",
        "The speed is impressive",
    ]

    issues = [
        "It stopped working",
        "The battery drains quickly",
        "The build quality is poor",
        "It feels cheap and flimsy",
        "The performance is sluggish",
        "It overheats constantly",
        "The screen is dim",
        "The sound quality is terrible",
    ]

    aspects = [
        "The price is reasonable",
        "It works as advertised",
        "The design is acceptable",
        "The features are basic",
        "The quality is average",
        "The performance is adequate",
    ]

    reviews = []
    sentiments = ["positive", "negative", "neutral"]

    for _ in range(num_samples):
        sentiment = random.choice(sentiments)
        template_text, template_summary = random.choice(templates[sentiment])

        # Pick two distinct phrases once and reuse them in both the review text
        # and its summary, so the summary actually describes the review.
        if sentiment == "positive":
            first, second = random.sample(features, 2)
            text = template_text.format(feature1=first, feature2=second)
            summary = template_summary.format(
                feature1=first.lower(), feature2=second.lower()
            )
        elif sentiment == "negative":
            first, second = random.sample(issues, 2)
            text = template_text.format(issue1=first, issue2=second)
            summary = template_summary.format(
                issue1=first.lower(), issue2=second.lower()
            )
        else:
            first, second = random.sample(aspects, 2)
            text = template_text.format(aspect1=first, aspect2=second)
            summary = template_summary.format(
                aspect1=first.lower(), aspect2=second.lower()
            )

        reviews.append({"text": text, "summary": summary, "sentiment": sentiment})

    print(f"✅ Generated {len(reviews)} synthetic reviews")
    return reviews


def split_and_save_data(
    reviews: List[Dict], output_dir: Path, train_ratio=0.8, val_ratio=0.1
):
    """Split data into train/val/test and save as JSONL"""

    # Shuffle
    random.shuffle(reviews)

    # Calculate splits
    total = len(reviews)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    train_data = reviews[:train_size]
    val_data = reviews[train_size : train_size + val_size]
    test_data = reviews[train_size + val_size :]

    # Save files
    output_dir.mkdir(parents=True, exist_ok=True)

    def save_jsonl(data, filename):
        filepath = output_dir / filename
        with open(filepath, "w") as f:
            for item in data:
                f.write(json.dumps(item) + "\n")
        print(f"  ✓ {filename}: {len(data)} samples")

    print(f"\n💾 Saving data to {output_dir}/")
    save_jsonl(train_data, "train.jsonl")
    save_jsonl(val_data, "validation.jsonl")
    save_jsonl(test_data, "test.jsonl")

    print(f"\n📊 Data split:")
    print(f"  Training:   {len(train_data)} samples ({train_ratio*100:.0f}%)")
    print(f"  Validation: {len(val_data)} samples ({val_ratio*100:.0f}%)")
    print(
        f"  Test:       {len(test_data)} samples ({(1-train_ratio-val_ratio)*100:.0f}%)"
    )


def main():
    parser = argparse.ArgumentParser(
        description="Download and prepare training data for review summarization",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Download from Hugging Face (recommended)
  python scripts/download_training_data.py --source huggingface --dataset amazon_polarity --num-samples 5000

  # Download different Hugging Face dataset
  python scripts/download_training_data.py --source huggingface --dataset yelp_review_full --num-samples 3000

  # Download real Amazon reviews (Electronics)
  python scripts/download_training_data.py --source amazon --max-samples 5000

  # Generate synthetic data for testing
  python scripts/download_training_data.py --source synthetic --num-samples 1000

  # Custom output directory
  python scripts/download_training_data.py --source huggingface --dataset imdb --output-dir my_data/
        """,
    )

    parser.add_argument(
        "--source",
        type=str,
        default="huggingface",
        choices=["huggingface", "amazon", "synthetic"],
        help="Data source (default: huggingface)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="amazon_polarity",
        help="Hugging Face dataset name (default: amazon_polarity)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="training_data",
        help="Output directory (default: training_data)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=5000,
        help="Max samples to download from Amazon (default: 5000)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=5000,
        help="Number of samples to generate/download (default: 5000)",
    )
    parser.add_argument(
        "--category",
        type=str,
        default="Electronics",
        help="Amazon reviews category (default: Electronics)",
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)

    print("=" * 60)
    print("📚 Training Data Preparation")
    print("=" * 60)

    # Get data based on source
    if args.source == "huggingface":
        reviews = download_huggingface_dataset(
            output_dir=output_dir,
            dataset_name=args.dataset,
            max_samples=args.num_samples,
        )
        if not reviews:
            print("\n❌ Failed to download from Hugging Face.")
            print("Please check your internet connection or try a different dataset.")
            return
    elif args.source == "amazon":
        reviews = download_amazon_reviews(
            output_dir=output_dir, category=args.category, max_samples=args.max_samples
        )
        if not reviews:
            print("\n❌ Failed to download Amazon reviews.")
            print("Please check your internet connection or AWS CLI configuration.")
            return
    else:
        reviews = create_synthetic_data(args.num_samples)

    # Split and save
    if reviews:
        split_and_save_data(reviews, output_dir)

        print("\n" + "=" * 60)
        print("✅ Data preparation complete!")
        print("=" * 60)
        print(f"\nNext steps:")
        print(f"1. Review the data in {output_dir}/")
        print(f"2. Upload to S3:")
        print(f"   python scripts/upload_training_data.py")
        print(f"3. Start training:")
        print(f"   python scripts/start_training.py")
    else:
        print("\n❌ No data was generated")

if __name__ == "__main__":
    main()

The script downloads the dataset from the source passed with --source; if no source is given, it defaults to Hugging Face (amazon_polarity).
Since we use instruction fine-tuning, each record is formatted as follows (synthetic records omit the source field):

{
    "text": text,
    "summary": summary,
    "sentiment": sentiment,
    "source": dataset_name,
}

before being split into training, validation, and test sets.
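split_and_save_data shuffles the caller's list in place using the global random module without a seed, so every run produces a different test set and metrics from two training runs are hard to compare. Below is a minimal sketch of a reproducible alternative; `split_reviews` and its `seed` parameter are hypothetical names, not part of the author's script. It keeps the same size arithmetic (the test set receives the remainder) but shuffles a copy with a private `random.Random(seed)`.

```python
import random
from typing import Dict, List, Sequence, Tuple


def split_reviews(
    reviews: Sequence[Dict],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split reviews into train/validation/test lists reproducibly.

    Shuffles a copy with a private Random(seed), so the caller's list and the
    global random state are left untouched and the same seed always gives the
    same split.
    """
    if not (0 < train_ratio < 1 and 0 <= val_ratio < 1 and train_ratio + val_ratio <= 1):
        raise ValueError("expected 0 < train_ratio < 1, 0 <= val_ratio < 1, train + val <= 1")

    shuffled = list(reviews)
    random.Random(seed).shuffle(shuffled)

    # Same size arithmetic as split_and_save_data: the test set gets the remainder.
    total = len(shuffled)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    train = shuffled[:train_size]
    val = shuffled[train_size : train_size + val_size]
    test = shuffled[train_size + val_size :]
    return train, val, test
```

For example, `train, val, test = split_reviews(reviews, seed=42)` on 5,000 records yields 4,000, 500 and 500 items, and rerunning with the same seed reproduces the same test set.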
Upload to S3 bucket
Let’s create a script that uploads the datasets to the S3 training bucket:

#!/usr/bin/env python3
"""
Upload training data to S3 training bucket

This script uploads your prepared training data to the S3 bucket created by
the training pipeline stack. It automatically finds the correct bucket name
from CloudFormation outputs.

Prerequisites:
    1. Deploy training pipeline: cdk deploy TrainingPipeline
    2. Prepare data: python scripts/download_training_data.py

Usage:
    # Upload data for kate environment
    python scripts/upload_training_data.py

    # Upload for different environment
    python scripts/upload_training_data.py --environment dev
"""

import boto3
import argparse
from pathlib import Path


def get_training_bucket(environment='kate'):
    """Get training bucket name from CloudFormation stack"""
    cfn = boto3.client('cloudformation')
    stack_name = f'training-pipeline-{environment}'

    try:
        response = cfn.describe_stacks(StackName=stack_name)
    except Exception as e:
        print(f"❌ Error: Could not describe stack '{stack_name}': {e}")
        print("Make sure you've deployed the training pipeline first:")
        print("  cdk deploy TrainingPipeline")
        return None

    # A stack without outputs has no 'Outputs' key at all
    for output in response['Stacks'][0].get('Outputs', []):
        if output['OutputKey'] == 'TrainingBucketName':
            return output['OutputValue']

    print("❌ Error: Could not find TrainingBucketName in stack outputs")
    return None


def upload_directory(local_dir: Path, bucket_name: str, s3_prefix: str = ''):
    """Upload directory contents to S3"""
    s3 = boto3.client('s3')

    if not local_dir.exists():
        print(f"❌ Error: Directory not found: {local_dir}")
        print(f"\nRun this first to download training data:")
        print(f"  python scripts/download_training_data.py")
        return False

    # Get list of files
    files = list(local_dir.glob('*.jsonl'))

    if not files:
        print(f"❌ Error: No .jsonl files found in {local_dir}")
        print(f"\nExpected files:")
        print(f"  - train.jsonl")
        print(f"  - validation.jsonl")
        print(f"  - test.jsonl")
        return False

    print(f"\n📤 Uploading {len(files)} files to s3://{bucket_name}/{s3_prefix}")
    print("=" * 60)

    uploaded = 0
    for file_path in files:
        # Join prefix and file name with exactly one '/' (accepts "data/v1" or "data/v1/")
        s3_key = f"{s3_prefix.rstrip('/')}/{file_path.name}" if s3_prefix else file_path.name

        try:
            # Get file size
            file_size = file_path.stat().st_size
            file_size_mb = file_size / (1024 * 1024)

            print(f"  Uploading {file_path.name} ({file_size_mb:.2f} MB)...", end='', flush=True)

            # upload_file switches to multipart uploads automatically for large files
            s3.upload_file(str(file_path), bucket_name, s3_key)

            print(" ✓")
            uploaded += 1

        except Exception as e:
            print(f" ✗")
            print(f"    Error: {str(e)}")

    print("=" * 60)
    print(f"✅ Uploaded {uploaded}/{len(files)} files successfully")

    return uploaded == len(files)


def verify_upload(bucket_name: str, s3_prefix: str = ''):
    """Verify files were uploaded correctly"""
    s3 = boto3.client('s3')

    print(f"\n🔍 Verifying upload...")

    try:
        response = s3.list_objects_v2(
            Bucket=bucket_name,
            Prefix=s3_prefix
        )

        if 'Contents' not in response:
            print("❌ No files found in bucket")
            return False

        print(f"\n📁 Files in s3://{bucket_name}/{s3_prefix}")
        print("=" * 60)

        total_size = 0
        for obj in response['Contents']:
            key = obj['Key']
            size = obj['Size']
            size_mb = size / (1024 * 1024)
            total_size += size
            print(f"  ✓ {key} ({size_mb:.2f} MB)")

        total_size_mb = total_size / (1024 * 1024)
        print("=" * 60)
        print(f"Total: {len(response['Contents'])} files, {total_size_mb:.2f} MB")

        return True

    except Exception as e:
        print(f"❌ Error verifying upload: {str(e)}")
        return False


def main():
    parser = argparse.ArgumentParser(
        description='Upload training data to S3',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Upload data for kate environment
  python scripts/upload_training_data.py

  # Upload for different environment
  python scripts/upload_training_data.py --environment dev

  # Upload from custom directory
  python scripts/upload_training_data.py --data-dir my_data/

  # Upload to specific S3 prefix
  python scripts/upload_training_data.py --s3-prefix data/v1/
        """
    )

    parser.add_argument('--environment', type=str, default='kate',
                        help='Environment name (default: kate)')
    parser.add_argument('--data-dir', type=str, default='training_data',
                        help='Local data directory (default: training_data)')
    parser.add_argument('--s3-prefix', type=str, default='',
                        help='S3 prefix/folder (default: root)')

    args = parser.parse_args()

    print("=" * 60)
    print("📤 Upload Training Data to S3")
    print("=" * 60)

    # Get training bucket
    print(f"\n🔍 Looking up training bucket for environment: {args.environment}")
    bucket_name = get_training_bucket(args.environment)

    if not bucket_name:
        return

    print(f"✓ Found bucket: {bucket_name}")

    # Upload files
    local_dir = Path(args.data_dir)
    success = upload_directory(local_dir, bucket_name, args.s3_prefix)

    if not success:
        return

    # Verify upload
    verify_upload(bucket_name, args.s3_prefix)

    print("\n" + "=" * 60)
    print("✅ Upload complete!")
    print("=" * 60)
    print(f"\nNext steps:")
    print(f"1. Start training job:")
    print(f"   python scripts/start_training.py --environment {args.environment}")
    print(f"\n2. Monitor training:")
    print(f"   - AWS Console: https://console.aws.amazon.com/sagemaker/home#/jobs")
    print(f"   - CLI: aws sagemaker list-training-jobs --sort-by CreationTime --sort-order Descending")


if __name__ == '__main__':
    main()

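upload_directory only checks that .jsonl files exist; it does not look inside them, so a malformed record only surfaces later, inside the SageMaker job. A small pre-flight check can catch this before you pay for a training instance. The sketch below assumes the record schema shown earlier (text, summary, sentiment) and the three file names written by split_and_save_data; the function names and the allowed sentiment labels are assumptions, not part of the author's scripts.

```python
import json
from pathlib import Path
from typing import Iterable, List

# Assumed schema: the fields shown in the record format earlier. The allowed
# sentiment labels match the synthetic generator; adjust them if your
# Hugging Face or Amazon mapping produces other labels.
REQUIRED_FIELDS = ("text", "summary", "sentiment")
ALLOWED_SENTIMENTS = {"positive", "negative", "neutral"}


def validate_lines(lines: Iterable[str], name: str = "<input>") -> List[str]:
    """Return human-readable problems found in JSONL lines (empty list = OK)."""
    problems: List[str] = []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # skip blank lines, e.g. a trailing newline
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append(f"{name}:{line_no}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(record, dict):
            problems.append(f"{name}:{line_no}: expected a JSON object")
            continue
        for field in REQUIRED_FIELDS:
            value = record.get(field)
            if not isinstance(value, str) or not value.strip():
                problems.append(f"{name}:{line_no}: missing or empty '{field}'")
        sentiment = record.get("sentiment")
        if isinstance(sentiment, str) and sentiment.strip() and sentiment not in ALLOWED_SENTIMENTS:
            problems.append(f"{name}:{line_no}: unexpected sentiment '{sentiment}'")
    return problems


def validate_jsonl(path: Path) -> List[str]:
    """Validate one JSONL file on disk."""
    with path.open(encoding="utf-8") as f:
        return validate_lines(f, path.name)


def validate_dir(data_dir: Path) -> List[str]:
    """Validate the three split files written by split_and_save_data."""
    problems: List[str] = []
    for filename in ("train.jsonl", "validation.jsonl", "test.jsonl"):
        path = data_dir / filename
        if not path.exists():
            problems.append(f"{filename}: file not found")
        else:
            problems.extend(validate_jsonl(path))
    return problems
```

Running `validate_dir(Path("training_data"))` before upload_training_data.py returns an empty list when every record passes.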
Training script

#!/usr/bin/env python3
"""
Start a SageMaker training job for fine-tuning review summarization model

This script starts a SageMaker training job that fine-tunes a T5 or DistilBERT
model on your review data. It automatically configures the job using resources
from your deployed training pipeline stack.

Prerequisites:
    1. Deploy training pipeline: cdk deploy TrainingPipeline
    2. Prepare data: python scripts/download_training_data.py
    3. Upload data: python scripts/upload_training_data.py

Usage:
    # Start training with defaults (t5-small, 3 epochs, ml.g4dn.xlarge GPU)
    python scripts/start_training.py

    # Custom hyperparameters
    python scripts/start_training.py --epochs 5 --batch-size 16 --learning-rate 3e-5

    # Use a larger GPU instance
    python scripts/start_training.py --instance-type ml.p3.2xlarge
"""

import boto3
import argparse
from datetime import datetime
import os

# Get region from environment or use default
REGION = os.environ.get('AWS_REGION') or os.environ.get('AWS_DEFAULT_REGION') or 'ap-southeast-2'

sagemaker_client = boto3.client('sagemaker', region_name=REGION)
cfn = boto3.client('cloudformation', region_name=REGION)
s3 = boto3.client('s3', region_name=REGION)
sts = boto3.client('sts', region_name=REGION)


def get_stack_outputs(stack_name):
    """Get outputs from CloudFormation stack"""
    response = cfn.describe_stacks(StackName=stack_name)
    outputs = {}
    for output in response['Stacks'][0]['Outputs']:
        outputs[output['OutputKey']] = output['OutputValue']
    return outputs


def upload_training_code(model_bucket):
    """Upload training script to S3"""
    import tarfile
    import tempfile

    # Create a temporary tar.gz file with the training code
    with tempfile.NamedTemporaryFile(suffix='.tar.gz', delete=False) as tmp:
        tmp_path = tmp.name

    try:
        with tarfile.open(tmp_path, 'w:gz') as tar:
            tar.add('sagemaker/train.py', arcname='train.py')
            tar.add('sagemaker/requirements.txt', arcname='requirements.txt')

        # Upload to S3
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        s3_key = f'code/sourcedir-{timestamp}.tar.gz'
        s3.upload_file(tmp_path, model_bucket, s3_key)

        return f's3://{model_bucket}/{s3_key}'
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


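def get_training_image():
    """Get the PyTorch training container image for the configured region"""
    # Use the same REGION as the boto3 clients above; a bare Session() may have
    # no region configured, which would produce an invalid image URI.
    region = REGION

    # PyTorch 2.0 GPU training image: pair it with a GPU instance type
    pytorch_version = '2.0.1'
    python_version = 'py310'

    # ECR image URI format. 763104351884 is the Deep Learning Containers account
    # for most commercial regions; some opt-in regions use a different account.
    image_uri = f'763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:{pytorch_version}-gpu-{python_version}-cu118-ubuntu20.04-sagemaker'

    return image_uri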


def start_training_job(
    environment='kate',
    model_name='t5-small',
    epochs=3,
    batch_size=8,
    learning_rate=2e-5,
    instance_type='ml.g4dn.xlarge',  # GPU instance, matching the CLI default and the GPU image
    use_lora=True,
    lora_r=8,
    lora_alpha=32,
    lora_dropout=0.1
):
    """Start a SageMaker training job"""

    # Get stack outputs
    stack_name = f'training-pipeline-{environment}'
    print(f"Getting outputs from stack: {stack_name}")

    try:
        outputs = get_stack_outputs(stack_name)
    except Exception:
        print(f"Error: Could not find stack '{stack_name}'")
        print("Make sure you've deployed the training pipeline first:")
        print("  cdk deploy TrainingPipeline")
        return

    training_bucket = outputs['TrainingBucketName']
    model_bucket = outputs['ModelBucketName']
    sagemaker_role = outputs['SageMakerRoleArn']

    print(f"Training bucket: {training_bucket}")
    print(f"Model bucket: {model_bucket}")
    print(f"SageMaker role: {sagemaker_role}")

    # Upload training code to S3
    print("\nUploading training code to S3...")
    source_code_uri = upload_training_code(model_bucket)
    print(f"Training code uploaded to: {source_code_uri}")

    # Generate job name with timestamp
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    job_name = f'review-summarizer-{environment}-{timestamp}'

    # Training job configuration
    training_config = {
        'TrainingJobName': job_name,
        'RoleArn': sagemaker_role,
        'AlgorithmSpecification': {
            'TrainingImage': get_training_image(),
            'TrainingInputMode': 'File',
        },
        'InputDataConfig': [
            {
                'ChannelName': 'training',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': f's3://{training_bucket}/',
                        'S3DataDistributionType': 'FullyReplicated',
                    }
                },
                'ContentType': 'application/json',
                'CompressionType': 'None',
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': f's3://{model_bucket}/models/',
        },
        'ResourceConfig': {
            'InstanceType': instance_type,
            'InstanceCount': 1,
            'VolumeSizeInGB': 30,
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': 86400,  # 24 hours
        },
        'HyperParameters': {
            'sagemaker_program': 'train.py',
            'sagemaker_submit_directory': source_code_uri,
            'epochs': str(epochs),
            'batch_size': str(batch_size),
            'learning_rate': str(learning_rate),
            'model_name': model_name,
            'use_lora': str(use_lora).lower(),
            'lora_r': str(lora_r),
            'lora_alpha': str(lora_alpha),
            'lora_dropout': str(lora_dropout),
        },
        'Tags': [
            {'Key': 'Environment', 'Value': environment},
            {'Key': 'Project', 'Value': 'review-summarizer'},
        ],
    }

    print(f"\nStarting training job: {job_name}")
    print(f"Model: {model_name}")
    print(f"Instance: {instance_type}")
    print(f"Training method: {'LoRA (Parameter-Efficient)' if use_lora else 'Full Fine-Tuning'}")
    print(f"Hyperparameters:")
    print(f"  - Epochs: {epochs}")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Learning rate: {learning_rate}")
    if use_lora:
        print(f"  - LoRA rank: {lora_r}")
        print(f"  - LoRA alpha: {lora_alpha}")
        print(f"  - LoRA dropout: {lora_dropout}")

    try:
        response = sagemaker_client.create_training_job(**training_config)
        print(f"\n✅ Training job started successfully!")
        print(f"Job ARN: {response['TrainingJobArn']}")
        print(f"\nMonitor progress:")
        region = REGION
        print(f"  - AWS Console: https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job_name}")
        print(f"  - CLI: aws sagemaker describe-training-job --training-job-name {job_name}")
        print(f"\nView logs:")
        print(f"  aws logs tail /aws/sagemaker/TrainingJobs --follow --log-stream-name-prefix {job_name}")

    except Exception as e:
        print(f"\n❌ Error starting training job: {str(e)}")
        print("\nTroubleshooting:")
        print(f"1. Make sure training data exists in s3://{training_bucket}/")
        print("2. Check IAM role has necessary permissions")
        print("3. Verify the training image is available in your region")
        print("4. Check training script exists: sagemaker/train.py")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Start SageMaker training job for review summarization',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Start training with defaults
  python scripts/start_training.py

  # Custom hyperparameters
  python scripts/start_training.py --epochs 5 --batch-size 16

  # Use larger instance
  python scripts/start_training.py --instance-type ml.p3.2xlarge

  # Different environment
  python scripts/start_training.py --environment dev
        """
    )

    parser.add_argument('--environment', type=str, default='kate',
                        help='Environment name (default: kate)')
    parser.add_argument('--model-name', type=str, default='t5-small',
                        help='Base model to fine-tune (default: t5-small)')
    parser.add_argument('--epochs', type=int, default=3,
                        help='Number of training epochs (default: 3)')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='Training batch size (default: 8)')
    parser.add_argument('--learning-rate', type=float, default=2e-5,
                        help='Learning rate (default: 2e-5)')
    parser.add_argument('--instance-type', type=str, default='ml.g4dn.xlarge',
                        help='SageMaker instance type (default: ml.g4dn.xlarge)')
    parser.add_argument('--use-lora', action='store_true', default=True,
                        help='Enable LoRA fine-tuning (default: True)')
    parser.add_argument('--no-lora', dest='use_lora', action='store_false',
                        help='Disable LoRA and use full fine-tuning')
    parser.add_argument('--lora-r', type=int, default=8,
                        help='LoRA rank (default: 8)')
    parser.add_argument('--lora-alpha', type=int, default=32,
                        help='LoRA alpha scaling (default: 32)')
    parser.add_argument('--lora-dropout', type=float, default=0.1,
                        help='LoRA dropout (default: 0.1)')

    args = parser.parse_args()

    start_training_job(
        environment=args.environment,
        model_name=args.model_name,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        instance_type=args.instance_type,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )

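start_training.py passes lora_r, lora_alpha and lora_dropout to train.py, which isn't shown here. Assuming train.py uses the standard LoRA formulation (as implemented in Hugging Face PEFT, where the low-rank update is scaled by lora_alpha / lora_r), the arithmetic below shows what the defaults mean for a single weight matrix. Which modules train.py actually adapts is not shown, so the example counts one projection only; the helper names are hypothetical.

```python
def lora_trainable_params(d_out: int, d_in: int, r: int) -> int:
    """Trainable parameters LoRA adds for one frozen weight of shape (d_out, d_in).

    LoRA keeps W frozen and learns a low-rank update B @ A, where
    B has shape (d_out, r) and A has shape (r, d_in), so it trains
    r * (d_out + d_in) values instead of d_out * d_in.
    """
    if r <= 0:
        raise ValueError("LoRA rank r must be positive")
    return r * (d_out + d_in)


def lora_scaling(alpha: float, r: int) -> float:
    """Scale applied to the update: W_eff = W + (alpha / r) * (B @ A)."""
    return alpha / r


def trainable_fraction(d_out: int, d_in: int, r: int) -> float:
    """Share of the full matrix's parameter count that LoRA actually trains."""
    return lora_trainable_params(d_out, d_in, r) / (d_out * d_in)


# One 512 x 512 attention projection (t5-small uses d_model = 512) with the
# script's defaults lora_r=8 and lora_alpha=32.
full_params = 512 * 512
lora_params = lora_trainable_params(512, 512, r=8)
scale = lora_scaling(alpha=32, r=8)
print(f"full: {full_params}, LoRA: {lora_params}, "
      f"fraction: {trainable_fraction(512, 512, 8):.4f}, scaling: {scale}")
```

For one 512 x 512 projection, LoRA trains 8,192 values instead of 262,144 (about 3%), and the update is multiplied by 32 / 8 = 4. Raising lora_r increases the trainable parameter count linearly in r, while lora_alpha only changes how strongly the update is applied.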
3. Inference Pipeline

3.1 Create the pipeline

This stack creates the following AWS resources:

  • S3 results bucket
  • API Gateway REST API (when enableApiGateway is set)
  • Lambda function (the summarizer)
  • IAM role and policies for the Lambda function
  • CloudWatch Logs log group
/**
 * Inference Pipeline Stack
 * 
 * This stack creates the infrastructure for online review summarization.
 * It implements a multi-stage processing pipeline:
 * 
 * 1. API Gateway - REST API endpoint for incoming requests
 * 2. Lambda Orchestrator - Coordinates the summarization pipeline
 * 3. Amazon Bedrock - Generates fast, general-purpose summaries
 * 4. Amazon OpenSearch - Retrieves relevant context via RAG (optional)
 * 5. SageMaker Endpoint - Refines summary with fine-tuned model (optional)
 * 6. S3 Results Bucket - Stores final summaries and metadata
 * 
 * Request Flow:
 * POST /summarize → Lambda → Bedrock → OpenSearch → SageMaker → S3 → Response
 * 
 * The Lambda function reads the active SageMaker endpoint from Parameter Store,
 * enabling zero-downtime model updates when new versions are deployed.
 */

import * as cdk from 'aws-cdk-lib';
import { Construct } from 'constructs';
import * as s3 from 'aws-cdk-lib/aws-s3';
import * as lambda from 'aws-cdk-lib/aws-lambda';
import { PythonFunction } from '@aws-cdk/aws-lambda-python-alpha';
import * as iam from 'aws-cdk-lib/aws-iam';
import * as apigateway from 'aws-cdk-lib/aws-apigateway';
import * as logs from 'aws-cdk-lib/aws-logs';
import * as ssm from 'aws-cdk-lib/aws-ssm';
import { EnvironmentConfig } from './utils';

export interface InferencePipelineStackProps extends cdk.StackProps {
  config: EnvironmentConfig;
  endpointParameterName: string;
  appConfigApplicationId?: string;
  appConfigEnvironmentId?: string;
  appConfigProfileId?: string;
}

export class InferencePipelineStack extends cdk.Stack {
  public readonly api: apigateway.RestApi;
  public readonly resultsBucket: s3.Bucket;
  public readonly summarizerFunction: lambda.Function;

  constructor(scope: Construct, id: string, props: InferencePipelineStackProps) {
    super(scope, id, props);

    const { config, endpointParameterName, appConfigApplicationId, appConfigEnvironmentId, appConfigProfileId } = props;

    // ========================================
    // S3 Bucket for Results
    // ========================================

    // NOTE: Using DESTROY for cost-saving during development
    // Results are temporary and can be safely deleted
    this.resultsBucket = new s3.Bucket(this, 'ResultsBucket', {
      bucketName: `summarizer-results-${config.environmentName}-${cdk.Aws.ACCOUNT_ID}`,
      removalPolicy: cdk.RemovalPolicy.DESTROY,
      autoDeleteObjects: true,
      encryption: s3.BucketEncryption.S3_MANAGED,
      lifecycleRules: [
        {
          id: 'DeleteOldResults',
          expiration: cdk.Duration.days(30),
        },
      ],
    });

    // ========================================
    // Lambda: Main Summarizer Function
    // ========================================

    const summarizerLogGroup = new logs.LogGroup(this, 'SummarizerLogGroup', {
      logGroupName: `/aws/lambda/summarizer-${config.environmentName}`,
      retention: logs.RetentionDays.ONE_WEEK,
      removalPolicy: cdk.RemovalPolicy.DESTROY,
    });

    const summarizerFn = new PythonFunction(this, 'SummarizerFunction', {
      functionName: `summarizer-${config.environmentName}`,
      entry: 'src/lambdas/summarizer',
      runtime: lambda.Runtime.PYTHON_3_11,
      index: 'handler.py',
      handler: 'handler',
      description: `Review summarization function for ${config.environmentName}`,
      timeout: cdk.Duration.seconds(120),
      memorySize: 1024,
      environment: {
        RESULTS_BUCKET: this.resultsBucket.bucketName,
        ENDPOINT_PARAMETER: endpointParameterName,
        ENVIRONMENT: config.environmentName,
        OPENSEARCH_ENDPOINT: process.env.OPENSEARCH_ENDPOINT || 'none',
        // AppConfig IDs (if provided)
        ...(appConfigApplicationId && { APPCONFIG_APPLICATION_ID: appConfigApplicationId }),
        ...(appConfigEnvironmentId && { APPCONFIG_ENVIRONMENT_ID: appConfigEnvironmentId }),
        ...(appConfigProfileId && { APPCONFIG_CONFIGURATION_PROFILE_ID: appConfigProfileId }),
      },
      logGroup: summarizerLogGroup,
    });

    // Expose Lambda function for AppConfig permissions
    this.summarizerFunction = summarizerFn;

    // Grant permissions
    this.resultsBucket.grantWrite(summarizerFn);

    summarizerFn.addToRolePolicy(new iam.PolicyStatement({
      effect: iam.Effect.ALLOW,
      actions: ['bedrock:InvokeModel'],
      resources: ['*'],
    }));

    summarizerFn.addToRolePolicy(new iam.PolicyStatement({
      effect: iam.Effect.ALLOW,
      actions: ['sagemaker:InvokeEndpoint'],
      resources: ['*'],
    }));

    summarizerFn.addToRolePolicy(new iam.PolicyStatement({
      effect: iam.Effect.ALLOW,
      actions: ['ssm:GetParameter'],
      resources: [
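        // Assumes endpointParameterName starts with '/', since the ARN is built
        // as "parameter" + name with no separator in between
        `arn:aws:ssm:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:parameter${endpointParameterName}`,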
      ],
    }));

    // OpenSearch permissions (granted unconditionally; only used when OPENSEARCH_ENDPOINT is set)
    summarizerFn.addToRolePolicy(new iam.PolicyStatement({
      effect: iam.Effect.ALLOW,
      actions: [
        'aoss:APIAccessAll',
        'es:ESHttpGet',
        'es:ESHttpPost',
      ],
      resources: ['*'],
    }));

    // ========================================
    // API Gateway
    // ========================================

    if (config.enableApiGateway) {
      this.api = new apigateway.RestApi(this, 'SummarizerAPI', {
        restApiName: `review-summarizer-${config.environmentName}`,
        description: 'API for review summarization with RAG',
        deployOptions: {
          stageName: config.environmentName,
          loggingLevel: apigateway.MethodLoggingLevel.INFO,
          dataTraceEnabled: true,
          metricsEnabled: true,
        },
        defaultCorsPreflightOptions: {
          allowOrigins: apigateway.Cors.ALL_ORIGINS,
          allowMethods: apigateway.Cors.ALL_METHODS,
        },
      });

      // POST /summarize endpoint
      const summarize = this.api.root.addResource('summarize');
      summarize.addMethod('POST', new apigateway.LambdaIntegration(summarizerFn), {
        apiKeyRequired: false,
        requestValidator: new apigateway.RequestValidator(this, 'RequestValidator', {
          restApi: this.api,
          validateRequestBody: true,
        }),
      });

      // GET /health endpoint
      const health = this.api.root.addResource('health');
      health.addMethod('GET', new apigateway.MockIntegration({
        integrationResponses: [{
          statusCode: '200',
          responseTemplates: {
            'application/json': '{"status": "healthy"}',
          },
        }],
        requestTemplates: {
          'application/json': '{"statusCode": 200}',
        },
      }), {
        methodResponses: [{ statusCode: '200' }],
      });

      new cdk.CfnOutput(this, 'ApiUrl', {
        value: this.api.url,
        description: 'API Gateway URL',
        exportName: `${config.environmentName}-api-url`,
      });
    }

    // ========================================
    // Outputs
    // ========================================

    new cdk.CfnOutput(this, 'ResultsBucketName', {
      value: this.resultsBucket.bucketName,
      description: 'S3 bucket for summarization results',
      exportName: `${config.environmentName}-results-bucket`,
    });

    new cdk.CfnOutput(this, 'LambdaFunctionName', {
      value: summarizerFn.functionName,
      description: 'Lambda function for summarization',
      exportName: `${config.environmentName}-summarizer-function`,
    });
  }
}

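With enableApiGateway turned on, the stack exports ApiUrl, and the handler expects a JSON body with text plus optional options, userId, category and userTier fields. The sketch below calls POST /summarize using only the Python standard library; `build_summarize_request` and `summarize` are hypothetical helper names, and the URL is whatever your ApiUrl output contains.

```python
import json
import urllib.request
from typing import Any, Dict, Optional


def build_summarize_request(
    api_url: str,
    text: str,
    include_sentiment: bool = True,
    use_rag: bool = True,
    user_id: Optional[str] = None,
) -> urllib.request.Request:
    """Build a POST /summarize request in the shape the handler expects.

    api_url is the stack's ApiUrl output, which ends with the stage name
    and a trailing slash.
    """
    body: Dict[str, Any] = {
        "text": text,
        "options": {"include_sentiment": include_sentiment, "use_rag": use_rag},
    }
    if user_id is not None:
        body["userId"] = user_id  # read by the handler for A/B assignment
    return urllib.request.Request(
        url=api_url.rstrip("/") + "/summarize",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def summarize(api_url: str, text: str, timeout: float = 35.0) -> Dict[str, Any]:
    """Send the request and decode the JSON response body."""
    request = build_summarize_request(api_url, text)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

Keep in mind that API Gateway REST APIs have a default integration timeout of 29 seconds, shorter than the Lambda's 120-second timeout, so a slow Bedrock plus SageMaker round trip can return a 504 to the client even if the function eventually finishes. urlopen raises urllib.error.HTTPError for 4xx and 5xx responses.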
3.2 Create the scripts

Lambda function
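The handler delegates model choice to select_model_for_request in appconfig_helper, whose code is not part of this section. As an illustration only, one common way to implement a sticky percentage rollout is to hash a stable key into 100 buckets and send the lowest N buckets to the candidate model. The function names, the salt, and the idea that candidate_percent would come from the AppConfig profile are assumptions, not the author's implementation.

```python
import hashlib
from typing import Optional


def ab_bucket(key: str, salt: str = "review-summarizer-ab") -> int:
    """Deterministically map a key (e.g. a userId) to a bucket in [0, 100)."""
    digest = hashlib.sha256(f"{salt}:{key}".encode("utf-8")).digest()
    # Use the first 8 bytes as an unsigned integer; modulo bias is negligible.
    return int.from_bytes(digest[:8], "big") % 100


def choose_variant(
    user_id: Optional[str],
    request_id: str,
    candidate_percent: int,
    control: str,
    candidate: str,
) -> str:
    """Route candidate_percent% of traffic to the candidate model.

    Requests that carry a userId are assigned stickily (same user, same
    variant); anonymous requests fall back to hashing the request id, so
    they are spread across variants but not sticky.
    """
    if not 0 <= candidate_percent <= 100:
        raise ValueError("candidate_percent must be between 0 and 100")
    key = user_id or request_id
    return candidate if ab_bucket(key) < candidate_percent else control
```

A call such as `choose_variant(body.get('userId'), request_id, 10, 'summarizer-v1', 'summarizer-v2')` (endpoint names hypothetical) keeps a given user on the same model across requests, which makes A/B comparisons cleaner. Raising candidate_percent from 10 to 50 only moves users from control to candidate, never back.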

"""
Main Lambda function for review summarization pipeline

This function orchestrates a multi-stage summarization process with A/B testing support:

Stage 1: Amazon Bedrock
    - Generates fast, general-purpose summary
    - Uses Claude or other foundation models
    - Always runs (provides baseline summary)

Stage 2: RAG Retrieval (Optional)
    - Queries OpenSearch vector index for relevant context
    - Grounds summary in factual knowledge
    - Only runs if OpenSearch is configured

Stage 3: SageMaker Refinement (Optional with A/B Testing)
    - Selects model based on A/B testing rules
    - Calls fine-tuned model for domain-specific refinement
    - Extracts sentiment and confidence scores
    - Supports gradual rollouts and canary deployments

Stage 4: Storage
    - Saves results to S3 for audit trail
    - Returns JSON response to API Gateway

The function uses AWS AppConfig for dynamic A/B testing configuration,
enabling gradual model rollouts without code changes.
"""

import json
import os
import boto3
from datetime import datetime
import traceback
from appconfig_helper import (
    select_model_for_request,
    get_bedrock_config,
    log_ab_test_assignment
)
from bedrock_client import summarize_review

# Initialize AWS clients
bedrock_runtime = boto3.client('bedrock-runtime', region_name=os.environ.get('AWS_REGION', 'us-east-1'))
sagemaker_runtime = boto3.client('sagemaker-runtime')
s3_client = boto3.client('s3')
ssm_client = boto3.client('ssm')

RESULTS_BUCKET = os.environ['RESULTS_BUCKET']
ENDPOINT_PARAMETER = os.environ['ENDPOINT_PARAMETER']
BEDROCK_MODEL_ID = os.environ.get('BEDROCK_MODEL_ID', 'anthropic.claude-v2')
OPENSEARCH_ENDPOINT = os.environ.get('OPENSEARCH_ENDPOINT', 'none')


def handler(event, context):
    """
    Main handler for summarization requests

    Expected input:
    {
        "text": "Review text here...",
        "options": {
            "include_sentiment": true,
            "use_rag": true
        }
    }
    """
    try:
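        # Parse input: API Gateway proxy events carry the payload as a JSON string
        # in 'body' (None when the request has no body); direct invocations pass it as-is
        if 'body' in event:
            body = json.loads(event['body'] or '{}')
        else:
            body = event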

        text = body.get('text', '')
        options = body.get('options', {})

        # Extract request context for A/B testing
        request_context = {
            'category': body.get('category', 'general'),
            'userTier': body.get('userTier', 'standard'),
            'textLength': len(text),
            'userId': body.get('userId'),
            # The Lambda context object exposes the invocation ID as aws_request_id
            'requestId': getattr(context, 'aws_request_id', None) or datetime.now().isoformat(),
        }

        if not text:
            return {
                'statusCode': 400,
                'body': json.dumps({'error': 'Missing required field: text'})
            }

        request_id = request_context['requestId']

        # Step 1: Get initial summary from Bedrock using Converse API
        print(f"[{request_id}] Step 1: Calling Bedrock for initial summary")
        bedrock_config = get_bedrock_config()

        bedrock_response = summarize_review(
            text=text,
            model_id=bedrock_config.get('modelId', BEDROCK_MODEL_ID),
            max_tokens=bedrock_config.get('maxTokens', 200),
            temperature=bedrock_config.get('temperature', 0.5)
        )

        initial_summary = bedrock_response['text']

        # Log token usage
        usage = bedrock_response['usage']
        print(f"Bedrock usage - Input: {usage['inputTokens']}, Output: {usage['outputTokens']}")

        # Step 2: RAG retrieval (optional)
        context_text = ""
        if options.get('use_rag', False) and OPENSEARCH_ENDPOINT != 'none':
            print(f"[{request_id}] Step 2: Retrieving context from OpenSearch")
            context_text = retrieve_context(text)
        else:
            print(f"[{request_id}] Step 2: Skipping RAG (disabled or not configured)")

        # Step 3: Select model using A/B testing
        print(f"[{request_id}] Step 3: Selecting model via A/B testing")
        endpoint_name = select_model_for_request(request_context)

        # Log A/B test assignment
        log_ab_test_assignment(request_id, endpoint_name)

        # Step 4: Refine with fine-tuned model (if endpoint exists)
        final_summary = initial_summary
        sentiment = "neutral"
        confidence = 0.0

        if endpoint_name and endpoint_name != 'none' and endpoint_name != 'ensemble':
            print(f"[{request_id}] Step 4: Refining with SageMaker endpoint: {endpoint_name}")
            refinement = refine_with_sagemaker(
                endpoint_name=endpoint_name,
                summary=initial_summary,
                context=context_text,
                original_text=text
            )
            final_summary = refinement.get('summary', initial_summary)
            sentiment = refinement.get('sentiment', 'neutral')
            confidence = refinement.get('confidence', 0.0)
        elif endpoint_name == 'ensemble':
            print(f"[{request_id}] Step 4: Using multi-model ensemble")
            # TODO: Implement ensemble logic
            final_summary = initial_summary
            sentiment = "neutral"
            confidence = 0.0
        else:
            print(f"[{request_id}] Step 4: Skipping SageMaker refinement (no endpoint configured)")

        # Step 5: Store results
        result = {
            'request_id': request_id,
            'timestamp': datetime.now().isoformat(),
            'initial_summary': initial_summary,
            'final_summary': final_summary,
            'sentiment': sentiment,
            'confidence': confidence,
            'used_rag': options.get('use_rag', False) and OPENSEARCH_ENDPOINT != 'none',
            'model_endpoint': endpoint_name,
            'request_context': request_context,
        }

        # Save to S3
        s3_key = f"results/{datetime.now().strftime('%Y/%m/%d')}/{request_id}.json"
        s3_client.put_object(
            Bucket=RESULTS_BUCKET,
            Key=s3_key,
            Body=json.dumps(result, indent=2),
            ContentType='application/json'
        )

        print(f"[{request_id}] Complete. Results saved to s3://{RESULTS_BUCKET}/{s3_key}")

        return {
            'statusCode': 200,
            'headers': {
                'Content-Type': 'application/json',
                'Access-Control-Allow-Origin': '*'
            },
            'body': json.dumps(result)
        }

    except Exception as e:
        print(f"Error: {str(e)}")
        print(traceback.format_exc())
        return {
            'statusCode': 500,
            'body': json.dumps({
                'error': str(e),
                'traceback': traceback.format_exc()
            })
        }


def retrieve_context(query: str, top_k: int = 3) -> str:
    """
    Retrieve relevant context from OpenSearch
    TODO: Implement OpenSearch vector search
    """
    # Placeholder - implement OpenSearch integration
    return ""


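def refine_with_sagemaker(endpoint_name: str, summary: str, context: str, original_text: str) -> dict:
    """
    Refine the summary with the fine-tuned SageMaker model.

    The model receives the original text and its 'summary' field is used; sentiment
    extraction is not implemented yet, and the context argument is currently unused.
    """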
    try:
        # Send original text to the model for summarization
        payload = {
            "inputs": original_text
        }

        response = sagemaker_runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType='application/json',
            Body=json.dumps(payload)
        )

        result = json.loads(response['Body'].read().decode())

        # Extract the summary from the model's response
        refined_summary = result.get('summary', summary)

        return {
            'summary': refined_summary,
            'sentiment': 'neutral',  # TODO: Add sentiment analysis
            'confidence': 0.0
        }

    except Exception as e:
        print(f"SageMaker error: {str(e)}")
        return {
            'summary': summary,
            'sentiment': 'neutral',
            'confidence': 0.0
        }
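The handler accepts both API Gateway proxy events and direct invocations. As a sketch of a slightly more defensive parsing step, the helpers below assume a REST API proxy integration, where body can be null and isBase64Encoded marks an encoded payload, and they read the invocation ID from the aws_request_id attribute of the Python Lambda context object. parse_request_body and resolve_request_id are hypothetical names, not functions from the repo.

```python
import base64
import json
from datetime import datetime
from typing import Any, Optional


def parse_request_body(event: dict) -> dict:
    """Return the JSON payload from an API Gateway proxy event or a direct invocation."""
    if "body" not in event:
        # Direct Lambda invocation: the event itself is the payload
        return event
    raw = event.get("body") or "{}"  # body is None when the request has no payload
    if event.get("isBase64Encoded"):
        raw = base64.b64decode(raw).decode("utf-8")
    return json.loads(raw)


def resolve_request_id(context: Optional[Any]) -> str:
    """Prefer the Lambda invocation ID; fall back to a timestamp outside Lambda."""
    return getattr(context, "aws_request_id", None) or datetime.now().isoformat()
```

Inside the handler, body = parse_request_body(event) and 'requestId': resolve_request_id(context) would replace the inline versions.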

Script to test the endpoint

#!/bin/bash
# Test script for the news summarization API

set -e

# Get API URL from CloudFormation stack
STACK_NAME="${1:-inference-pipeline-kate}"

echo "Getting API URL from stack: $STACK_NAME"
API_URL=$(aws cloudformation describe-stacks \
  --stack-name "$STACK_NAME" \
  --query 'Stacks[0].Outputs[?OutputKey==`ApiUrl`].OutputValue' \
  --output text)

if [ -z "$API_URL" ]; then
  echo "Error: Could not find API URL in stack outputs"
  exit 1
fi

echo "API URL: $API_URL"
echo ""

# Test 1: Health check
echo "Test 1: Health Check"
echo "===================="
curl -s "${API_URL}health" | jq .
echo -e "\n"

# Test 2: Technology news article
echo "Test 2: Technology News Article"
echo "================================"
curl -s -X POST "${API_URL}summarize" \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Apple Inc. announced today the launch of its latest iPhone model, featuring significant improvements in camera technology and battery life. The new device includes a 48-megapixel main camera, up from the previous 12-megapixel sensor, and promises up to 20 hours of video playback. The company also introduced new AI-powered features for photo editing and enhanced security measures. Pre-orders begin next Friday, with the device hitting stores two weeks later. Industry analysts predict strong sales, particularly in the premium smartphone segment. The starting price is set at $999 for the base model.",
    "options": {
      "use_rag": false
    }
  }' | jq .
echo -e "\n"

# Test 3: Political news article
echo "Test 3: Political News Article"
echo "==============================="
curl -s -X POST "${API_URL}summarize" \
  -H "Content-Type: application/json" \
  -d '{
    "text": "The Senate voted 65-35 today to pass a comprehensive infrastructure bill worth $1.2 trillion. The bipartisan legislation includes funding for roads, bridges, public transit, and broadband internet expansion. Supporters argue the bill will create millions of jobs and modernize aging infrastructure. Critics express concerns about the cost and potential impact on the federal deficit. The bill now moves to the House of Representatives for consideration. President Biden praised the Senate vote, calling it a historic investment in America future. The legislation has been in negotiation for months.",
    "options": {
      "use_rag": false
    }
  }' | jq .
echo -e "\n"

# Test 4: Business news article
echo "Test 4: Business News Article"
echo "=============================="
curl -s -X POST "${API_URL}summarize" \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Tesla reported record quarterly earnings today, beating Wall Street expectations. The electric vehicle maker delivered 250,000 vehicles in the quarter, a 40% increase from the same period last year. Revenue reached $13.8 billion, up from $10.4 billion a year ago. CEO Elon Musk attributed the strong performance to increased production capacity and growing demand for electric vehicles globally. The company also announced plans to build two new manufacturing facilities in Europe and Asia. Tesla stock rose 8% in after-hours trading following the earnings announcement.",
    "options": {
      "use_rag": false
    }
  }' | jq .
echo -e "\n"

# Test 5: Sports news article
echo "Test 5: Sports News Article"
echo "============================"
curl -s -X POST "${API_URL}summarize" \
  -H "Content-Type: application/json" \
  -d '{
    "text": "In a thrilling championship game, the Lakers defeated the Celtics 108-105 to win their 18th NBA title. LeBron James led the team with 32 points, 11 rebounds, and 8 assists in what many are calling one of the greatest performances in Finals history. The victory came after the Lakers trailed by 15 points in the third quarter. Anthony Davis contributed 28 points and played crucial defense in the final minutes. This marks the Lakers first championship in over a decade. Head coach Frank Vogel praised the team resilience and determination throughout the playoffs.",
    "options": {
      "use_rag": false
    }
  }' | jq .
echo -e "\n"

echo "All tests completed!"
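The same calls can also be made from Python with only the standard library. This sketch assumes, as the shell script does, that API_URL ends with a trailing slash; build_summarize_request and summarize are hypothetical helpers. Serializing the payload with json.dumps also avoids a limit of the shell version, whose single-quoted -d payloads cannot contain apostrophes.

```python
import json
import urllib.request


def build_summarize_request(api_url: str, text: str, use_rag: bool = False) -> urllib.request.Request:
    """Build a POST request for the summarize route (api_url should end with '/')."""
    payload = {"text": text, "options": {"use_rag": use_rag}}
    return urllib.request.Request(
        url=f"{api_url}summarize",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def summarize(api_url: str, text: str, use_rag: bool = False, timeout: float = 30.0) -> dict:
    """Send the request and decode the JSON response body."""
    request = build_summarize_request(api_url, text, use_rag)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

For example, summarize(api_url, article_text) returns the same JSON document that each curl call prints.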

Deploy the app

1. Deploy the resources on AWS

cdk deploy --all
This deploys three stacks: AppConfigStack, TrainingPipelineStack, and InferencePipelineStack.

2. Download the training data

python3 scripts/download_training_data.py \
  --source huggingface \
  --dataset cnn_dailymail \
  --num-samples 5000
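The download script's internals are not shown here, but the overview describes formatting the data and splitting it into training, validation, and test sets. A minimal sketch of a reproducible split could look like this, assuming an 80/10/10 ratio, a fixed random seed, and records already mapped to {"text": ..., "summary": ...} (for cnn_dailymail the source columns are article and highlights). split_dataset and write_jsonl are hypothetical names.

```python
import json
import random
from pathlib import Path
from typing import Dict, List, Sequence


def split_dataset(records: Sequence[dict], train: float = 0.8, val: float = 0.1,
                  seed: int = 42) -> Dict[str, List[dict]]:
    """Shuffle deterministically, then cut into train/validation/test lists."""
    if not 0 < train + val < 1:
        raise ValueError("train + val must be between 0 and 1")
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)  # seeded so reruns produce the same split
    n_train = int(len(shuffled) * train)
    n_val = int(len(shuffled) * val)
    return {
        "train": shuffled[:n_train],
        "validation": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],  # the remainder goes to test
    }


def write_jsonl(records: List[dict], path: Path) -> None:
    """Write one JSON object per line, a common input format for training scripts."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```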

3. Upload the datasets to S3 bucket

python3 scripts/upload_training_data.py 

4. Start training job

python3 scripts/start_training.py \
  --model-name t5-base \
  --epochs 5 \
  --batch-size 4 \
  --instance-type ml.g4dn.xlarge

The script uses LoRA for fine-tuning by default. If you want full fine-tuning instead, pass --no-lora explicitly:

python3 scripts/start_training.py --no-lora
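A --lora / --no-lora flag pair with LoRA on by default maps naturally onto argparse.BooleanOptionalAction (Python 3.9+). The parser below is a hypothetical sketch of how start_training.py might declare its options; the defaults are simply copied from the example command above and may not match the real script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical CLI mirroring the flags used in the commands above."""
    parser = argparse.ArgumentParser(description="Start a SageMaker fine-tuning job")
    parser.add_argument("--model-name", default="t5-base")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--instance-type", default="ml.g4dn.xlarge")
    # BooleanOptionalAction generates both --lora and --no-lora;
    # LoRA stays enabled unless --no-lora is passed.
    parser.add_argument("--lora", action=argparse.BooleanOptionalAction, default=True)
    return parser
```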

5. Get the metrics

# Download metrics from S3
# Get job name from previous step
JOB_NAME="review-summarizer-kate-xxxxxx-xxxxxx"


MODEL_BUCKET=$(aws cloudformation describe-stacks \
 --stack-name training-pipeline-kate \
 --query 'Stacks[0].Outputs[?OutputKey==`ModelBucketName`].OutputValue' \
 --output text)


aws s3 cp s3://$MODEL_BUCKET/models/$JOB_NAME/output/output.tar.gz .
tar -xzf output.tar.gz


# View metrics
cat metrics.json
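If you would rather not unpack the whole archive, a sketch using the standard tarfile module can read metrics.json directly from output.tar.gz, either from the downloaded file or from raw bytes (for example, bytes fetched with boto3). read_metrics is a hypothetical helper; it matches on the file name so that an entry stored as ./metrics.json is also found.

```python
import io
import json
import tarfile
from typing import Union


def read_metrics(tarball: Union[str, bytes], member: str = "metrics.json") -> dict:
    """Read a JSON file straight out of a .tar.gz archive without extracting to disk."""
    if isinstance(tarball, bytes):
        tar = tarfile.open(fileobj=io.BytesIO(tarball), mode="r:gz")
    else:
        tar = tarfile.open(tarball, mode="r:gz")
    with tar:
        for info in tar.getmembers():
            # Compare base names so "./metrics.json" also matches "metrics.json"
            if info.isfile() and info.name.rsplit("/", 1)[-1] == member:
                extracted = tar.extractfile(info)
                return json.loads(extracted.read().decode("utf-8"))
    raise FileNotFoundError(f"{member} not found in archive")
```

After the aws s3 cp step, read_metrics("output.tar.gz")["test_rouge_l"] returns the test score.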

The result will look like this:

{
  "validation_rouge_l": 0.2694268479883026,
  "test_rouge_l": 0.27888068965217705,
  "final_train_loss": 0.8831174189448356,
  "use_lora": true,
  "trainable_params": 884736,
  "model_name": "t5-base",
  "epochs": 5,
  "batch_size": 4,
  "learning_rate": 3e-05,
  "lora_config": {
    "r": 8,
    "alpha": 32,
    "dropout": 0.1
  } 
}
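validation_rouge_l and test_rouge_l are ROUGE-L scores, which measure the overlap between a generated summary and the reference summary through their longest common subsequence (LCS) of tokens. The sketch below computes an F1 variant on lowercase whitespace tokens. It is illustrative only: the source does not say which implementation the training script uses, and library implementations apply their own tokenization and optional stemming, so this simplified function will not reproduce the reported numbers exactly.

```python
from typing import List


def lcs_length(a: List[str], b: List[str]) -> int:
    """Length of the longest common subsequence via O(len(a) * len(b)) dynamic programming."""
    prev = [0] * (len(b) + 1)
    for token_a in a:
        curr = [0]
        for j, token_b in enumerate(b, start=1):
            if token_a == token_b:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(reference: str, candidate: str) -> float:
    """ROUGE-L F1 on lowercase whitespace tokens (no stemming)."""
    ref, cand = reference.lower().split(), candidate.lower().split()
    if not ref or not cand:
        return 0.0
    lcs = lcs_length(ref, cand)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)
```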

Depending on the results, you can adjust the training script's parameters to improve them, for example by switching to a different pre-trained model or adding more training data.
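Such changes also affect how many weights are actually trained. The trainable_params value of 884,736 in the metrics above can be reproduced with a short calculation, assuming LoRA adapters with r = 8 on the query (q) and value (v) projections of every attention layer in t5-base. Those are the default T5 target modules in Hugging Face PEFT; that the training script keeps them is an assumption. t5-base has d_model = 768, 12 encoder blocks (self-attention only) and 12 decoder blocks (self-attention plus cross-attention), and each adapter on a 768 x 768 weight adds an r x 768 matrix A and a 768 x r matrix B.

```python
def lora_trainable_params(d_model: int, rank: int, num_adapted_matrices: int) -> int:
    """Each LoRA adapter on a d_model x d_model weight adds A (r x d) and B (d x r)."""
    return num_adapted_matrices * rank * (d_model + d_model)


# t5-base: d_model = 768, 12 encoder blocks, 12 decoder blocks.
# Assumed targets: the q and v projections of every attention layer.
encoder_matrices = 12 * 2        # self-attention q, v
decoder_matrices = 12 * (2 + 2)  # self-attention q, v + cross-attention q, v
print(lora_trainable_params(768, 8, encoder_matrices + decoder_matrices))  # 884736
```

Compared with the roughly 220 million parameters in t5-base, that is well under 1% of the model, which is why LoRA runs fit on a single ml.g4dn.xlarge.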

6. Register Model

If you are happy with the results, register the model and wait for approval:

# Create model package
aws sagemaker create-model-package \
 --model-package-group-name "review-summarizer" \
 --model-package-description "Fine-tuned T5 for review summarization" \
 --inference-specification '{
   "Containers": [{
     "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310",
     "ModelDataUrl": "s3://'"$MODEL_BUCKET"'/models/'"$JOB_NAME"'/output/model.tar.gz"
   }],
   "SupportedContentTypes": ["application/json"],
   "SupportedResponseMIMETypes": ["application/json"]
 }' \
 --model-approval-status "PendingManualApproval"
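The container image in the inference specification has to be pulled from the same AWS Region as the model, yet the example ARN in the next step uses ap-southeast-2 while this image URI uses us-west-2. A sketch that derives the image URI from the Region keeps the two aligned. build_inference_spec is a hypothetical helper; the image tag is copied from the command above, and 763104351884 is the registry account used for AWS Deep Learning Containers in most, but not all, Regions.

```python
def build_inference_spec(region: str, model_bucket: str, job_name: str,
                         image_tag: str = "2.0.1-gpu-py310") -> dict:
    """Build the InferenceSpecification payload for create-model-package."""
    # Pull the PyTorch inference image from the same Region as the model package
    image = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-inference:{image_tag}"
    return {
        "Containers": [{
            "Image": image,
            "ModelDataUrl": f"s3://{model_bucket}/models/{job_name}/output/model.tar.gz",
        }],
        "SupportedContentTypes": ["application/json"],
        "SupportedResponseMIMETypes": ["application/json"],
    }


# Usage (requires AWS credentials):
# import boto3
# sm = boto3.client("sagemaker", region_name=region)
# sm.create_model_package(
#     ModelPackageGroupName="review-summarizer",
#     ModelPackageDescription="Fine-tuned T5 for review summarization",
#     InferenceSpecification=build_inference_spec(region, model_bucket, job_name),
#     ModelApprovalStatus="PendingManualApproval",
# )
```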

7. Approve the model package

# Get model package ARN from previous step
aws sagemaker list-model-packages --model-package-group-name "review-summarizer"


MODEL_PACKAGE_ARN="arn:aws:sagemaker:ap-southeast-2:123456789012:model-package/review-summarizer/1"


# Approve model
aws sagemaker update-model-package \
 --model-package-arn $MODEL_PACKAGE_ARN \
 --model-approval-status "Approved"
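Instead of copying the ARN by hand, the newest package awaiting approval can be picked out of the list-model-packages response. latest_pending_package_arn is a hypothetical helper that only reads fields returned in ModelPackageSummaryList; the commented boto3 calls show where it would fit.

```python
from typing import Any, Dict


def latest_pending_package_arn(list_response: Dict[str, Any]) -> str:
    """Pick the newest package awaiting approval from a list_model_packages response."""
    pending = [
        p for p in list_response.get("ModelPackageSummaryList", [])
        if p.get("ModelApprovalStatus") == "PendingManualApproval"
    ]
    if not pending:
        raise ValueError("no model package is pending manual approval")
    # boto3 returns CreationTime as a datetime, and max() compares those directly
    return max(pending, key=lambda p: p["CreationTime"])["ModelPackageArn"]


# Usage (requires AWS credentials):
# import boto3
# sm = boto3.client("sagemaker")
# resp = sm.list_model_packages(ModelPackageGroupName="review-summarizer")
# sm.update_model_package(ModelPackageArn=latest_pending_package_arn(resp),
#                         ModelApprovalStatus="Approved")
```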

Approving the model package automatically triggers a Lambda function that creates a SageMaker endpoint and updates Parameter Store with the new endpoint.
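This section does not show how the trigger is wired. One common approach, sketched here as an assumption rather than the repo's actual code, is an EventBridge rule matching SageMaker's "SageMaker Model Package State Change" events for the review-summarizer group, with the deployment Lambda double-checking the approval status before it creates the endpoint and writes to Parameter Store. APPROVAL_EVENT_PATTERN and approved_package_arn are hypothetical names.

```python
from typing import Any, Dict, Optional

# Event pattern an EventBridge rule could use to invoke the deployment Lambda
APPROVAL_EVENT_PATTERN = {
    "source": ["aws.sagemaker"],
    "detail-type": ["SageMaker Model Package State Change"],
    "detail": {
        "ModelPackageGroupName": ["review-summarizer"],
        "ModelApprovalStatus": ["Approved"],
    },
}


def approved_package_arn(event: Dict[str, Any]) -> Optional[str]:
    """Return the approved package ARN, or None if the event should be ignored."""
    if event.get("detail-type") != "SageMaker Model Package State Change":
        return None
    detail = event.get("detail", {})
    if detail.get("ModelApprovalStatus") != "Approved":
        return None
    # The deployment Lambda would create the endpoint from this ARN next
    return detail.get("ModelPackageArn")
```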

8. Test the API

./scripts/test_api.sh

Result

Test 1: Health Check
====================
{
  "status": "healthy"
}


Test 2: Technology News Article
================================
{
  "request_id": "2026-02-15T10:33:26.300629",
  "timestamp": "2026-02-15T10:33:30.740083",
  "initial_summary": "Here is a concise summary of the customer review:\n\nThe new iPhone model features significant upgrades, including a 48-megapixel main camera and up to 20 hours of video playback. It also includes new AI-powered photo editing features and enhanced security measures. Pre-orders begin next Friday, with the device launching two weeks later. Industry analysts predict strong sales, particularly in the premium smartphone segment, with a starting price of $999 for the base model.",
  "final_summary": "the new iPhone features a 48-megapixel main camera and 20 hours of video playback. the company also introduced new AI-powered features for photo editing. Industry analysts predict strong sales, particularly in the premium smartphone segment.",
  "sentiment": "neutral",
  "confidence": 0.0,
  "used_rag": false,
  "model_endpoint": "endpoint-kate",
  "request_context": {
    "category": "general",
    "userTier": "standard",
    "textLength": 602,
    "userId": null,
    "requestId": "2026-02-15T10:33:26.300629"
  }
}


Test 3: Political News Article
===============================
{
  "request_id": "2026-02-15T10:33:30.933191",
  "timestamp": "2026-02-15T10:33:34.585612",
  "initial_summary": "Here is a concise, objective summary of the customer review:\n\nThe Senate passed a $1.2 trillion bipartisan infrastructure bill that includes funding for roads, bridges, public transit, and broadband. Supporters say it will create jobs and modernize infrastructure, while critics are concerned about the cost and impact on the federal deficit. The bill now goes to the House for consideration, and President Biden praised the Senate's historic vote.",
  "final_summary": "the bill includes funding for roads, bridges, public transit, and broadband internet expansion. President Biden calls the vote a historic investment in America future.",
  "sentiment": "neutral",
  "confidence": 0.0,
  "used_rag": false,
  "model_endpoint": "endpoint-kate",
  "request_context": {
    "category": "general",
    "userTier": "standard",
    "textLength": 598,
    "userId": null,
    "requestId": "2026-02-15T10:33:30.933191"
  }
}


Test 4: Business News Article
==============================
{
  "request_id": "2026-02-15T10:33:34.705612",
  "timestamp": "2026-02-15T10:33:38.517618",
  "initial_summary": "Here is a concise, objective summary of the customer review:\n\nTesla reported record quarterly earnings, beating Wall Street expectations. The company delivered 250,000 vehicles, a 40% increase from the previous year, and revenue reached $13.8 billion. CEO Elon Musk attributed the strong performance to increased production capacity and growing global demand for electric vehicles. Tesla also announced plans to build two new manufacturing facilities in Europe and Asia, and the stock price rose 8% after the earnings announcement.",
  "final_summary": "Tesla delivered 250,000 vehicles in the quarter, a 40% increase from the same period last year. Revenue reached $13.8 billion, up from $10.4 billion a year ago.",
  "sentiment": "neutral",
  "confidence": 0.0,
  "used_rag": false,
  "model_endpoint": "endpoint-kate",
  "request_context": {
    "category": "general",
    "userTier": "standard",
    "textLength": 570,
    "userId": null,
    "requestId": "2026-02-15T10:33:34.705612"
  }
}


Test 5: Sports News Article
============================
{
  "request_id": "2026-02-15T10:33:38.619204",
  "timestamp": "2026-02-15T10:33:42.190132",
  "initial_summary": "In a closely contested NBA Finals, the Los Angeles Lakers defeated the Boston Celtics 108-105 to win their 18th championship. LeBron James delivered an outstanding performance with 32 points, 11 rebounds, and 8 assists, while Anthony Davis added 28 points and played strong defense in the closing minutes. The Lakers overcame a 15-point deficit in the third quarter to secure the victory, showcasing their resilience and determination throughout the playoffs, as praised by head coach Frank Vogel.",
  "final_summary": "LeBron James led the team with 32 points, 11 rebounds, and 8 assists. This is the Lakers first championship in over a decade.",
  "sentiment": "neutral",
  "confidence": 0.0,
  "used_rag": false,
  "model_endpoint": "endpoint-kate",
  "request_context": {
    "category": "general",
    "userTier": "standard",
    "textLength": 563,
    "userId": null,
    "requestId": "2026-02-15T10:33:38.619204"
  }
}


All tests completed!

Now we have a complete fine-tuning pipeline with automatic model deployment. The application automatically uses the latest approved Amazon SageMaker endpoint for inference.
In addition, we can integrate Retrieval-Augmented Generation (RAG) into the pipeline. This involves setting up Amazon OpenSearch as a vector database, embedding relevant documents, and updating the Lambda function to retrieve contextual information before generating summaries (refer to this).
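retrieve_context in the Lambda is still a placeholder that returns an empty string. In the planned design, OpenSearch's k-NN search would rank stored document embeddings against an embedding of the review. The in-memory stand-in below shows the same ranking idea with cosine similarity and top_k = 3, matching the placeholder's default. The embeddings are assumed to come from some embedding model, and cosine_similarity and top_k_context are hypothetical helpers, not the OpenSearch API.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_k_context(query_vec: Sequence[float],
                  docs: List[Tuple[str, Sequence[float]]], k: int = 3) -> str:
    """Rank (text, embedding) pairs by similarity and join the best k into one context string."""
    ranked = sorted(docs, key=lambda d: cosine_similarity(query_vec, d[1]), reverse=True)
    return "\n\n".join(text for text, _ in ranked[:k])
```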
Currently, the system immediately switches to the new model once approved. However, we can implement A/B testing to gradually roll out the model, reducing potential risks and ensuring smoother transitions.
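The Lambda already delegates endpoint choice to select_model_for_request in appconfig_helper, whose code is not shown in this section. One way to implement a gradual rollout is deterministic hashing: hash the user ID (or the request ID when there is no user) into 100 buckets and map bucket ranges to endpoints by weight, so a returning user keeps seeing the same variant. The weights dictionary stands in for a configuration stored in AppConfig; its shape, select_endpoint, and the endpoint-kate-v2 name are assumptions.

```python
import hashlib
from typing import Dict, Optional


def select_endpoint(weights: Dict[str, int], user_id: Optional[str], request_id: str) -> str:
    """Pick an endpoint by percentage weight, keeping the same user on the same variant."""
    if sum(weights.values()) != 100:
        raise ValueError("weights must sum to 100")
    # Hash the user ID when present (sticky assignment), otherwise the request ID
    key = user_id or request_id
    bucket = int(hashlib.sha256(key.encode("utf-8")).hexdigest(), 16) % 100
    cumulative = 0
    for endpoint, weight in sorted(weights.items()):  # sorted for a stable order
        cumulative += weight
        if bucket < cumulative:
            return endpoint
    raise AssertionError("unreachable when weights sum to 100")
```

SHA-256 is used instead of Python's built-in hash() because string hashing is randomized per process, which would break sticky assignment across Lambda containers. Moving endpoint-kate-v2 from 10 to 50 and then 100 in the configuration would shift traffic without redeploying the function.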
Link to the repo
