# Python
import os
import shutil
import tempfile

import boto3
import mlflow
# Set up S3 client
s3 = boto3.client("s3")

# MLflow tracking server. Overridable via the MLFLOW_TRACKING_URI environment
# variable so the same code can run in a notebook or a deployed function
# without source edits; the placeholder below is used when it is unset.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://your-mlflow-server")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

# Destination bucket and key prefix for mirrored artifacts. Each can be
# overridden by an environment variable of the same name; the placeholders
# are kept as defaults.
S3_BUCKET = os.environ.get("S3_BUCKET", "your-s3-bucket-name")
S3_DESTINATION_PREFIX = os.environ.get("S3_DESTINATION_PREFIX", "mlflow-artifacts/")
def get_prod_model(model_name="your-model-name", stage="production", *, client=None):
    """Return the newest registered model version in the given registry stage.

    Args:
        model_name: Registered model name to search. Defaults to the
            placeholder name the script originally hard-coded.
        stage: Registry stage to match, compared case-insensitively.
        client: Optional object exposing ``search_model_versions(filter)``
            (an ``MlflowClient`` or a stand-in). A new
            ``mlflow.tracking.MlflowClient()`` is created when omitted.

    Returns:
        The matching model version with the highest version number, or
        ``None`` if no version of ``model_name`` is in ``stage``.
    """
    if client is None:
        client = mlflow.tracking.MlflowClient()
    wanted = stage.lower()
    # NOTE: this matches the registry *stage*, not an MLflow alias; aliases
    # are a separate mechanism (MlflowClient.get_model_version_by_alias).
    # NOTE(review): stages are deprecated in recent MLflow releases in favour
    # of aliases — confirm against the deployed MLflow version.
    candidates = [
        mv
        for mv in client.search_model_versions(f"name='{model_name}'")
        # Guard against a None current_stage instead of crashing on .lower().
        if (mv.current_stage or "").lower() == wanted
    ]
    if not candidates:
        return None
    # Search order is not guaranteed, so pick the newest version
    # deterministically when several versions share the stage.
    return max(candidates, key=lambda mv: int(mv.version))
def download_and_upload_artifacts(model):
    """Download a model version's artifacts and mirror them to S3.

    Artifacts are fetched from ``model.source`` into a private temporary
    directory. Each file is then uploaded to
    ``{S3_DESTINATION_PREFIX}{model.version}/<path relative to the download>``
    in ``S3_BUCKET``, preserving the directory layout. The temporary
    directory is removed even if the download or an upload raises.

    Args:
        model: A model version object exposing ``source`` (the artifact URI)
            and ``version``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # dst_path must be passed by keyword: the second positional parameter
        # of download_artifacts is run_id, not the destination directory.
        downloaded = mlflow.artifacts.download_artifacts(
            artifact_uri=model.source, dst_path=tmp_dir
        )
        key_prefix = f"{S3_DESTINATION_PREFIX}{model.version}/"

        if os.path.isfile(downloaded):
            # Single-file artifact: os.walk would yield nothing for it.
            s3.upload_file(
                downloaded, S3_BUCKET, key_prefix + os.path.basename(downloaded)
            )
            return

        for root, _, files in os.walk(downloaded):
            for name in files:
                local_file_path = os.path.join(root, name)
                # Keep sub-directories in the key so same-named files in
                # different folders don't overwrite each other; S3 uses "/".
                rel_path = os.path.relpath(local_file_path, downloaded)
                s3_key = key_prefix + rel_path.replace(os.sep, "/")
                s3.upload_file(local_file_path, S3_BUCKET, s3_key)
def test_lambda_function():
    """Notebook smoke test: find the production model and push its artifacts to S3.

    Prints a status line for each outcome and always returns ``None``.
    """
    prod_version = get_prod_model()
    if prod_version:
        print(f"Found model version: {prod_version.version}")
        download_and_upload_artifacts(prod_version)
        print(f"Artifacts for model version {prod_version.version} uploaded successfully to S3.")
    else:
        print("No model found with alias 'prod'")
# Run the smoke test only when executed directly (a script, or a notebook
# cell, where __name__ is "__main__"), not when this module is imported —
# e.g. by pytest collecting the test_-prefixed function above.
if __name__ == "__main__":
    test_lambda_function()
Real challenges. Real solutions. Real talk.
From technical discussions to philosophical debates, AWS and AWS Partners examine the impact and evolution of gen AI.
For further actions, you may consider blocking this person and/or reporting abuse
Top comments (0)