DEV Community

Cover image for Deploying whisperX on AWS SageMaker as Asynchronous Endpoint
Mohamad Albaker Kawtharani
Mohamad Albaker Kawtharani

Posted on

Deploying whisperX on AWS SageMaker as Asynchronous Endpoint

Whisper is an automatic speech recognition system developed by OpenAI, designed to transcribe and translate audio into text across multiple languages, focusing on general-purpose transcription tasks.
WhisperX, on the other hand, extends Whisper's capabilities with enhancements like faster processing times and additional features such as voice activity detection (VAD), making it more suitable for specific applications that require these advanced functionalities.

First, we install the Hugging Face Hub library, enabling interaction with the Hugging Face Model Hub from our environment.

!pip install huggingface_hub
Enter fullscreen mode Exit fullscreen mode

We import necessary modules and define two functions. download_hf_model downloads a specified model from Hugging Face to a local directory. fetch_models downloads the WhisperX and Voice Activity Detection (VAD) models, storing them locally for use.

import huggingface_hub
import os
import urllib.request

def download_hf_model(model_name: str, hf_token: str, local_model_dir: str) -> str:
    """
    Fetches the provided model from HuggingFace and returns the subdirectory it is downloaded to.

    Parameters:
    - model_name: HuggingFace repo id, optionally followed by "@<revision>"
      (branch, tag or commit hash) to pin a specific revision.
    - hf_token: HuggingFace access token passed to the Hub client.
    - local_model_dir: Directory under which the snapshot is stored.
    Returns:
    The repo id (without the revision suffix), which is also the subdirectory
    of local_model_dir that receives the files.
    """
    # partition() also handles names without '@': revision is then empty.
    model_subdir, _, revision = model_name.partition('@')
    huggingface_hub.snapshot_download(
        model_subdir,
        # Honour a pinned revision instead of silently dropping it.
        revision=revision or None,
        token=hf_token,
        local_dir=f"{local_model_dir}/{model_subdir}",
        local_dir_use_symlinks=False,
    )
    return model_subdir

def fetch_models(hf_token: str, local_model_dir="./models"):
    """
    Fetches models required for WhisperX transcription without diarization.

    Two artifacts are placed under local_model_dir:
    - guillaumekln/faster-whisper-large-v2/ (snapshot from HuggingFace)
    - whisperx/vad/pytorch_model.bin (VAD segmentation weights, plain HTTPS download)

    Parameters:
    - hf_token: HuggingFace access token used for the snapshot download.
    - local_model_dir: Root directory that receives both models.
    """
    WHISPERX_MODEL = "guillaumekln/faster-whisper-large-v2"
    VAD_MODEL_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"

    # Fetch WhisperX model
    download_hf_model(model_name=WHISPERX_MODEL, hf_token=hf_token, local_model_dir=local_model_dir)

    # Fetch VAD Segmentation model
    vad_model_dir = os.path.join(local_model_dir, "whisperx", "vad")
    # exist_ok avoids the check-then-create race and makes re-runs harmless.
    os.makedirs(vad_model_dir, exist_ok=True)
    urllib.request.urlretrieve(VAD_MODEL_URL, os.path.join(vad_model_dir, "pytorch_model.bin"))
Enter fullscreen mode Exit fullscreen mode

Then we call the fetch_models function with our Hugging Face token to download the WhisperX and VAD models into the local directory ./models-v1.

# Pull the WhisperX and VAD weights into ./models-v1.
fetch_models(
    local_model_dir="./models-v1",
    hf_token="",  # enter your hugging face token
)
Enter fullscreen mode Exit fullscreen mode

Here, we set up the files our model needs. We create inference.py and requirements.txt inside the code/ subdirectory of our model directory.

inference.py contains the inference script for loading the WhisperX model, processing input audio files from S3, performing transcription, and formatting the output.
requirements.txt specifies the Python packages needed for the inference environment.

import os

# Model bundle root; both generated files live in its code/ subdirectory.
dir_path = './models-v1'
inference_file_path, requirements_file_path = (
    os.path.join(dir_path, relative_path)
    for relative_path in ('code/inference.py', 'code/requirements.txt')
)

# Make sure the code/ directory exists before anything is written into it.
os.makedirs(os.path.dirname(inference_file_path), exist_ok=True)

# Inference.py content: the SageMaker handler script, kept as a string so it
# can be written verbatim to code/inference.py.
inference_content = '''# inference.py
import io
import json
import logging
import os
import tempfile
import time
import boto3
import torch
import whisperx

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# CTranslate2 cannot run float16 on CPU, so fall back to int8 there.
COMPUTE_TYPE = 'float16' if DEVICE == 'cuda' else 'int8'
s3 = boto3.client('s3')

def model_fn(model_dir, context=None):
    """
    Load and return the WhisperX model necessary for audio transcription.
    """
    print("Entering model_fn")

    logging.info("Loading WhisperX model")
    model = whisperx.load_model(whisper_arch=f"{model_dir}/guillaumekln/faster-whisper-large-v2",
                                device=DEVICE,
                                language="en",
                                compute_type=COMPUTE_TYPE,
                                vad_options={'model_fp': f"{model_dir}/whisperx/vad/pytorch_model.bin"})
    print("Loaded WhisperX model")

    print("Exiting model_fn with model loaded")
    return {
        'model': model
    }

def input_fn(request_body, request_content_type):
    """
    Process and load audio from S3, given the request body containing S3 bucket and key.

    Returns the path of a local temporary file holding the audio; predict_fn
    is responsible for deleting it.
    """
    print("Entering input_fn")
    if request_content_type != 'application/json':
        raise ValueError("Invalid content type. Must be application/json")

    request = json.loads(request_body)
    s3_bucket = request['s3bucket']
    s3_key = request['s3key']

    # Reserve a temp path and close the handle right away so no file descriptor leaks.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        local_path = temp_file.name

    # Download the file from S3
    try:
        s3.download_file(Bucket=s3_bucket, Key=s3_key, Filename=local_path)
    except Exception:
        # Do not leave an orphaned temp file behind when the download fails.
        os.remove(local_path)
        raise
    print(f"Downloaded audio from S3: {s3_bucket}/{s3_key}")

    print("Exiting input_fn")
    return local_path

def predict_fn(input_data, model, context=None):
    """
    Perform transcription on the provided audio file and delete the file afterwards.
    """
    print("Entering predict_fn")
    start_time = time.time()

    whisperx_model = model['model']

    try:
        logging.info("Loading audio")
        audio = whisperx.load_audio(input_data)

        logging.info("Transcribing audio")
        transcription_result = whisperx_model.transcribe(audio, batch_size=16)
    finally:
        # Always remove the temp file, even when loading or transcription fails.
        try:
            os.remove(input_data)  # input_data contains the path to the temp file
            print(f"Temporary file {input_data} deleted.")
        except OSError as e:
            print(f"Error: {input_data} : {e.strerror}")

    end_time = time.time()
    elapsed_time = end_time - start_time
    logging.info(f"Transcription took {int(elapsed_time)} seconds")

    print(f"Exiting predict_fn, processing took {int(elapsed_time)} seconds")
    return transcription_result

def output_fn(prediction, accept, context=None):
    """
    Prepare the prediction result for the response.
    """
    print("Entering output_fn")
    if accept != "application/json":
        raise ValueError("Accept header must be application/json")
    response_body = json.dumps(prediction)
    print("Exiting output_fn with response prepared")
    return response_body, accept

'''

# Persist the handler source into the code/ directory as inference.py.
with open(inference_file_path, mode='w') as handler_script:
    handler_script.write(inference_content)

# Requirements.txt content: one pinned dependency per line.
requirements_content = "\n".join([
    "speechbrain==0.5.16",
    "faster-whisper==0.7.1",
    "git+https://github.com/m-bain/whisperx.git@1b092de19a1878a8f138f665b1467ca21b076e7e",
    "ffmpeg-python",
    "",  # keeps the trailing newline at the end of the file
])

# Persist the dependency list next to inference.py as requirements.txt.
with open(requirements_file_path, mode='w') as requirements_file:
    requirements_file.write(requirements_content)
Enter fullscreen mode Exit fullscreen mode

Then we compress our model directory into a gzip-compressed tar archive. The archive is created in the current working directory and contains the prepared model and the code needed to deploy WhisperX on SageMaker.

import shutil
# Pack the contents of ./models-v1 into ./modelv1.tar.gz.
shutil.make_archive(base_name='./modelv1', format='gztar', root_dir='./models-v1')
Enter fullscreen mode Exit fullscreen mode

After that, we upload the compressed model archive to an S3 bucket using the SageMaker session's upload_data method, and keep the S3 path of the uploaded model for the deployment step.

import sagemaker
import boto3
from sagemaker import get_execution_role

# Destination of the archive inside S3.
bucket = ''  # Enter your s3 bucket name
prefix = 'whisperx/code'

sagemaker_session = sagemaker.Session()

# Upload the model to S3 and keep the resulting s3:// URI for the deployment step.
s3_path = sagemaker_session.upload_data(path='modelv1.tar.gz', bucket=bucket, key_prefix=prefix)

print(f"Model uploaded to {s3_path}")
Enter fullscreen mode Exit fullscreen mode

We are now ready to create the SageMaker model and deploy it as an asynchronous endpoint, using a container image URI and the model data uploaded to S3.
For different deployment regions and requirements, find a suitable container image in the AWS Deep Learning Containers list.

import sagemaker
from sagemaker.model import Model
from sagemaker.async_inference import AsyncInferenceConfig

# Initialize a sagemaker session
sagemaker_session = sagemaker.Session()

# IAM role the endpoint runs under; get_execution_role() resolves it when this
# is executed inside SageMaker (notebook instance / Studio).
role = sagemaker.get_execution_role()

# Inference container image for the endpoint.
image_uri = ''  # Enter the URI of a PyTorch inference Deep Learning Container image for your region

# Create a SageMaker model
model = Model(
    image_uri=image_uri,
    role=role,
    model_data=s3_path,
    sagemaker_session=sagemaker_session,
)

# Specify the output location: results are written to the same bucket the
# model archive was uploaded to.
async_config = AsyncInferenceConfig(
    output_path=f's3://{bucket}/whisperx/output'
)

# Deploy the model to an asynchronous endpoint
predictor = model.deploy(
    initial_instance_count=1,
    instance_type='ml.g4dn.xlarge', # you can change
    async_inference_config=async_config
)

# The endpoint name is needed later to invoke the endpoint.
print(f"Endpoint name: {model.endpoint_name}")
Enter fullscreen mode Exit fullscreen mode

Now that the deployment is done, we can invoke the endpoint.
This function, invoke_async_model, is developed for asynchronous inference. It takes parameters for an S3 bucket and key, saves a JSON payload containing these details to S3, and invokes a SageMaker asynchronous endpoint with the location of this payload.
The function aims to facilitate the processing of inference jobs without requiring the client to wait for the job to complete, ideal for large-scale or batch processing tasks.
Upon successful invocation, it prints confirmation and the unique InferenceId provided by SageMaker, which can be used to track the status and result of the inference job.

import boto3
import json

def invoke_async_model(s3_bucket, s3_key, endpoint_name=''):
    """
    Saves a JSON payload to S3 and invokes a SageMaker asynchronous endpoint with the payload.

    Parameters:
    - s3_bucket: The S3 bucket name.
    - s3_key: The S3 key for the input file.
    - endpoint_name: Name of the asynchronous endpoint to invoke. The default
      is an empty placeholder; pass your endpoint name.
    Returns:
    The response from the endpoint invocation, including the InferenceId, or
    None when saving the payload or invoking the endpoint fails (the error is
    printed instead of raised).
    """
    s3 = boto3.client('s3')
    sagemaker_runtime = boto3.client('sagemaker-runtime')

    # Create the payload
    payload = {
        "s3bucket": s3_bucket,
        "s3key": s3_key
    }

    # Define the S3 key for the input JSON: same "folder" as the input file.
    # rpartition keeps a key without '/' at the bucket root instead of
    # treating the file name itself as a folder.
    # NOTE(review): every request under the same prefix reuses this key, so
    # overlapping invocations may overwrite each other's payload - confirm
    # whether a per-request unique key is needed.
    key_prefix, separator, _ = s3_key.rpartition('/')
    s3_key_for_input = f"{key_prefix}{separator}asynch_input_file.json"

    # Save the payload to S3 as a JSON file
    try:
        s3.put_object(
            Body=json.dumps(payload),
            Bucket=s3_bucket,
            Key=s3_key_for_input
        )
        print("Payload saved to S3.")
    except Exception as e:
        print(f"Error saving JSON to S3: {e}")
        return None

    # The S3 location of the input data for the inference request
    input_location = f"s3://{s3_bucket}/{s3_key_for_input}"

    # Invoke the SageMaker asynchronous endpoint
    try:
        response = sagemaker_runtime.invoke_endpoint_async(
            EndpointName=endpoint_name,
            InputLocation=input_location,
            ContentType='application/json'
        )
        print(f"Endpoint invoked. InferenceId: {response['InferenceId']}")
    except Exception as e:
        print(f"Error invoking endpoint: {e}")
        return None

    return response

# Example usage
request_body = dict(
    s3bucket='',  # add the s3 bucket
    s3key='',  # S3 key of the audio / video file
)

# Hand the two request fields to the invocation helper.
invoke_async_model(request_body['s3bucket'], request_body['s3key'])

Enter fullscreen mode Exit fullscreen mode

I hope this was helpful.
Please let me know if you have any questions!

Top comments (0)