SetFit is an efficient, prompt-free framework for few-shot fine-tuning of Sentence Transformers. It achieves high accuracy with little labeled data and is widely used as a text classifier in NLP tasks.
However, SetFit’s model weight format differs from the .bin file that PyTorch models typically use. Instead, SetFit uses the .safetensors format, making it incompatible with AWS SageMaker’s default images. In this article, we present a method to deploy SetFit, or any model framework using a non-standard weight format, while keeping the model invocation process straightforward.
The AWS services utilized in this deployment include Lambda, ECR, and S3. The design diagram is as follows.
Step 1: Compress the Fine-tuned SetFit Model
Here, we assume you have used the trainer to fine-tune your pre-trained SetFit model. For details on fine-tuning a SetFit model, refer to this resource. After fine-tuning, run the following Python code to save the model to a local directory.
trainer.model.save_pretrained('saved_model')
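After saving, it can be useful to confirm which weight files the directory actually contains and how large it is, since the whole directory is copied into the container image in the next step. The following is a minimal sketch using only the standard library. The helper name summarize_model_dir is hypothetical, and saved_model is simply the directory name used above.

```python
from pathlib import Path

# Extensions commonly used for serialized model weights.
WEIGHT_SUFFIXES = {".bin", ".safetensors"}


def summarize_model_dir(model_dir):
    """Return (weight_files, total_bytes) for all files under model_dir."""
    root = Path(model_dir)
    weight_files = []
    total_bytes = 0
    for path in sorted(root.rglob("*")):
        if not path.is_file():
            continue
        total_bytes += path.stat().st_size
        if path.suffix in WEIGHT_SUFFIXES:
            # Store paths relative to the model directory, with '/' separators.
            weight_files.append(path.relative_to(root).as_posix())
    return weight_files, total_bytes
```

Calling summarize_model_dir("saved_model") returns the list of weight files and the total size in bytes. Which files appear depends on your model and library versions.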
Step 2: Create the Inference Image for Lambda
To create the image, gather the following four items: Dockerfile, requirements.txt, app.py, and the saved_model/ directory from Step 1. Place them in the same directory.
Dockerfile
FROM public.ecr.aws/lambda/python:3.9
# Install dependencies
COPY requirements.txt ${LAMBDA_TASK_ROOT}
RUN pip install --no-cache-dir -r requirements.txt
# Set up the environment
ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUE
# Copy the model files and app code
COPY saved_model /app/saved_model
COPY app.py ${LAMBDA_TASK_ROOT}
ENV MODEL_PATH=/app/saved_model
ENV HF_HOME=/tmp/huggingface
CMD ["app.lambda_handler"]
requirements.txt
boto3==1.35.43
safetensors==0.4.5
sagemaker==2.232.2
torch==1.13.1
transformers==4.36.0
setfit==1.0.3
huggingface-hub==0.23.5
pandas
app.py
Here, we can modify the input_file_path and output_file_path. Currently, the input file is stored in an S3 bucket in .csv format, with the column for inference named 'query_text'.
import os
from urllib.parse import unquote_plus

# Point the Hugging Face cache at /tmp, the only writable path in Lambda.
# Set this before importing setfit so the library picks it up.
cache_dir = "/tmp/huggingface"
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir

import boto3
import pandas as pd
from setfit import SetFitModel

# Initialize S3 client
s3_client = boto3.client("s3")

# Load the model once per container, outside the handler
model = SetFitModel.from_pretrained(os.getenv("MODEL_PATH"))


def handle_s3_event(bucket_name, key):
    # Define paths
    input_file_path = "/tmp/input.csv"
    output_file_path = "/tmp/output.csv"

    # Define the output key (location) in S3
    output_key = key.replace("input", "output")
    if output_key == key:
        # Writing back to the same key would overwrite the input and
        # could trigger this function again.
        raise ValueError(f"Key does not contain 'input': {key}")
    print('output key: ' + output_key)

    # Download the CSV file from S3
    s3_client.download_file(bucket_name, key, input_file_path)

    # Load input data from CSV
    input_data = pd.read_csv(input_file_path)

    # The model reads the text to classify from the 'query_text' column
    predictions = model(input_data['query_text'].tolist())

    # Convert predictions to a DataFrame and save to CSV
    output_data = pd.DataFrame(predictions, columns=["prediction"])
    output_data['query_text'] = input_data['query_text']
    output_data.to_csv(output_file_path, index=False)

    # Upload the output CSV to S3
    s3_client.upload_file(output_file_path, bucket_name, output_key)


def lambda_handler(event, context):
    # Extract bucket name and object key from the S3 event.
    # Keys arrive URL-encoded in S3 notifications, so decode them first.
    bucket_name = event['Records'][0]['s3']['bucket']['name']
    object_key = unquote_plus(event['Records'][0]['s3']['object']['key'])
    print(f"Bucket: {bucket_name}, Key: {object_key}")

    # Process the S3 event
    handle_s3_event(bucket_name, object_key)
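An S3 event notification can carry more than one record, and object keys arrive URL-encoded (a space is delivered as '+'). The event-parsing step can be exercised locally, without AWS access, along these lines. This is a minimal sketch: the helper name extract_s3_objects, the bucket name, and the sample key are illustrative assumptions, and the sample event contains only the fields that app.py reads.

```python
from urllib.parse import unquote_plus


def extract_s3_objects(event):
    """Return a list of (bucket, key) pairs for every record in an S3 event."""
    objects = []
    for record in event.get("Records", []):
        bucket = record["s3"]["bucket"]["name"]
        # S3 URL-encodes object keys in event notifications,
        # so decode them before passing them to the S3 client.
        key = unquote_plus(record["s3"]["object"]["key"])
        objects.append((bucket, key))
    return objects


# A hand-written sample event (hypothetical bucket and key).
sample_event = {
    "Records": [
        {
            "s3": {
                "bucket": {"name": "my-bucket"},
                "object": {"key": "input/my+queries.csv"},
            }
        }
    ]
}

print(extract_s3_objects(sample_event))
# [('my-bucket', 'input/my queries.csv')]
```

Keeping the parsing in a small pure function makes it easy to unit-test separately from the model and the S3 client.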
After creating all the necessary files, build the Docker image using the following command:
docker build -t setfit-model-image .
Step 3: Push the Image to ECR
To use the image in an AWS Lambda function, we need to push it to Amazon ECR. First, create a build_image.sh script in the current folder. Then, make it executable with chmod +x build_image.sh and run it using ./build_image.sh.
build_image.sh
#!/bin/bash
# The name of our algorithm
algorithm_name=setfit-image
account=$(aws sts get-caller-identity --query Account --output text)
# Get the region defined in the current configuration
region=$(aws configure get region)
fullname="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:v1"
# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${algorithm_name}" > /dev/null 2>&1
if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" > /dev/null
fi
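# Log Docker in to ECR (the older "aws ecr get-login" command was removed in AWS CLI v2)
aws ecr get-login-password --region "${region}" | docker login --username AWS --password-stdin "${account}.dkr.ecr.${region}.amazonaws.com"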
# Build the docker image locally with the image name and then push it to ECR
# with the full name.
docker build -q -t ${algorithm_name} .
docker tag ${algorithm_name} ${fullname}
docker push ${fullname}
Step 4: Create a Container-based Lambda Function
After completing the previous steps, go to the AWS Console and create a new Lambda function. Here, we use an S3 event trigger: whenever a .csv file containing a query_text column is uploaded to the S3 bucket configured in the trigger, the Lambda function is invoked automatically. The inference results are saved to the same bucket, under the output key that app.py derives by replacing "input" with "output" in the uploaded object's key.
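Because the results are written back to the same bucket that triggers the function, the output key should never match the keys the trigger listens for. Otherwise each upload of results could invoke the function again. One way to make that explicit is a prefix-based mapping, sketched below. The function name derive_output_key and the input/ and output/ prefixes are assumptions for illustration, not something the setup above requires.

```python
def derive_output_key(key, input_prefix="input/", output_prefix="output/"):
    """Map an input key to its output key, or return None if it is not an input.

    Only the leading prefix is swapped, so 'input/input_a.csv' becomes
    'output/input_a.csv' rather than 'output/output_a.csv'.
    """
    if not key.startswith(input_prefix):
        # Not under the input prefix: skip it instead of writing back to
        # the same key, which could trigger the function again.
        return None
    return output_prefix + key[len(input_prefix):]


print(derive_output_key("input/input_a.csv"))   # output/input_a.csv
print(derive_output_key("output/input_a.csv"))  # None
```

Restricting the S3 trigger itself to the input prefix gives a second layer of protection against recursive invocations.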
Here, choose “Container image” on the Lambda function creation page. For the “Container image URI,” enter the image created in Step 3: “${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:v1”. You can find this URI in the ECR service through the AWS Console.
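The same console step can also be scripted. The sketch below composes the image URI in the form used by build_image.sh and builds the arguments for Lambda's CreateFunction API with a container image. The helper names, the account ID 123456789012, the function name, and the role ARN are all placeholders.

```python
def ecr_image_uri(account, region, repository, tag="v1"):
    """Compose the image URI in the same form as build_image.sh."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


def build_create_function_args(function_name, role_arn, image_uri):
    """Keyword arguments for Lambda's CreateFunction API with a container image."""
    return {
        "FunctionName": function_name,
        "Role": role_arn,
        "PackageType": "Image",
        "Code": {"ImageUri": image_uri},
    }


# Placeholder values; replace them with your own account, region, and role.
uri = ecr_image_uri("123456789012", "us-east-1", "setfit-image")
args = build_create_function_args(
    "setfit-inference",
    "arn:aws:iam::123456789012:role/setfit-lambda-role",
    uri,
)
print(args["Code"]["ImageUri"])
# 123456789012.dkr.ecr.us-east-1.amazonaws.com/setfit-image:v1
```

With boto3 installed and credentials configured, the dictionary can be passed as boto3.client("lambda").create_function(**args). The execution role must allow the function to read from and write to the S3 bucket.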
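After creating the Lambda function, make sure to increase the timeout in the Configuration settings. The default is 3 seconds, so without this change the function will almost certainly time out; the maximum allowed execution time for Lambda is 15 minutes. Loading the model also requires more than the default 128 MB of memory, so raise the memory setting as well. Finally, add an S3 bucket event trigger to the Lambda function.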
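For repeatable deployments, the timeout and the trigger can be expressed as data and applied with boto3. The following is a rough sketch: the helper names are hypothetical, and the input/ prefix and .csv suffix filter are assumptions about how the bucket is organized.

```python
MAX_LAMBDA_TIMEOUT_SECONDS = 900  # Lambda's 15-minute limit


def build_timeout_args(function_name, timeout_seconds=MAX_LAMBDA_TIMEOUT_SECONDS):
    """Keyword arguments for Lambda's UpdateFunctionConfiguration API."""
    if not 1 <= timeout_seconds <= MAX_LAMBDA_TIMEOUT_SECONDS:
        raise ValueError("timeout must be between 1 and 900 seconds")
    return {"FunctionName": function_name, "Timeout": timeout_seconds}


def build_s3_trigger_config(function_arn, prefix="input/", suffix=".csv"):
    """NotificationConfiguration for S3's PutBucketNotificationConfiguration API."""
    return {
        "LambdaFunctionConfigurations": [
            {
                "LambdaFunctionArn": function_arn,
                # Fire on any newly created object that matches the filter.
                "Events": ["s3:ObjectCreated:*"],
                "Filter": {
                    "Key": {
                        "FilterRules": [
                            {"Name": "prefix", "Value": prefix},
                            {"Name": "suffix", "Value": suffix},
                        ]
                    }
                },
            }
        ]
    }
```

With boto3, these would be applied as boto3.client("lambda").update_function_configuration(**build_timeout_args("setfit-inference")) and boto3.client("s3").put_bucket_notification_configuration(Bucket="my-bucket", NotificationConfiguration=build_s3_trigger_config(function_arn)), where the function and bucket names are placeholders. Note that the S3 call replaces the bucket's existing notification configuration, and that S3 needs permission to invoke the function, which the console sets up automatically when the trigger is added there.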
In this article, we demonstrated a simple approach to deploying a SetFit model using AWS Lambda, ECR, and S3. While this method is straightforward, it is limited by Lambda’s 15-minute execution time.
For workloads requiring longer inference times or greater scalability, alternative solutions such as SageMaker endpoints or ECS-based deployments may be more suitable. In future articles, we will explore these options to handle larger-scale model inference efficiently.