AWS SageMaker agent example#

Deploy and serve an XGBoost model on AWS SageMaker using FastAPI#

This example demonstrates how to deploy and serve an XGBoost model on SageMaker using FastAPI custom inference.

We train an XGBoost model on the Pima Indians Diabetes dataset and generate a tar.gz file to be stored in an S3 bucket.

Note

The model artifact needs to be available in an S3 bucket for SageMaker to be able to access.

import os
import tarfile
from pathlib import Path

import flytekit
from flytekit import ImageSpec, task, workflow
from flytekit.types.file import FlyteFile
from numpy import loadtxt
from sklearn.model_selection import train_test_split

train_model_image = ImageSpec(
    name="xgboost-train",
    registry="ghcr.io/flyteorg",
    packages=["xgboost"],
)

if train_model_image.is_container():
    from xgboost import XGBClassifier


@task(container_image=train_model_image)
def train_model(dataset: FlyteFile) -> FlyteFile:
    dataset = loadtxt(dataset.download(), delimiter=",")
    X = dataset[:, 0:8]
    Y = dataset[:, 8]
    X_train, _, y_train, _ = train_test_split(X, Y, test_size=0.33, random_state=7)

    model = XGBClassifier()
    model.fit(X_train, y_train)

    serialized_model = str(Path(flytekit.current_context().working_directory) / "xgboost_model.json")
    booster = model.get_booster()
    booster.save_model(serialized_model)

    return FlyteFile(path=serialized_model)


@task
def convert_to_tar(model: FlyteFile) -> FlyteFile:
    tf = tarfile.open("model.tar.gz", "w:gz")
    tf.add(model.download(), arcname="xgboost_model")
    tf.close()

    return FlyteFile("model.tar.gz")


@workflow
def sagemaker_xgboost_wf(
    dataset: FlyteFile = "https://dub.sh/VZrumbQ",
) -> FlyteFile:
    serialized_model = train_model(dataset=dataset)
    return convert_to_tar(model=serialized_model)

Note

Replace ghcr.io/flyteorg with a container registry to which you can publish. To upload the image to the local registry in the demo cluster, indicate the registry as localhost:30000.

The above workflow generates a compressed model artifact that can be stored in an S3 bucket. Take note of the S3 URI.

To deploy the model on SageMaker, use the awssagemaker_inference.create_sagemaker_deployment function.

from flytekit import kwtypes
from flytekitplugins.awssagemaker_inference import create_sagemaker_deployment

REGION = "us-east-2"
S3_OUTPUT_PATH = "s3://sagemaker-agent-xgboost/inference-output/output"
DEPLOYMENT_NAME = "xgboost-fastapi"

sagemaker_image = ImageSpec(
    name="sagemaker-xgboost",
    registry="ghcr.io/flyteorg",  # Amazon EC2 Container Registry or a Docker registry accessible from your VPC.
    packages=["xgboost", "fastapi", "uvicorn", "scikit-learn"],
    source_root=".",
).with_commands(["chmod +x /root/serve"])


sagemaker_deployment_wf = create_sagemaker_deployment(
    name="xgboost-fastapi",
    model_input_types=kwtypes(model_path=str, execution_role_arn=str),
    model_config={
        "ModelName": DEPLOYMENT_NAME,
        "PrimaryContainer": {
            "Image": "{images.primary_container_image}",
            "ModelDataUrl": "{inputs.model_path}",
        },
        "ExecutionRoleArn": "{inputs.execution_role_arn}",
    },
    endpoint_config_input_types=kwtypes(instance_type=str),
    endpoint_config_config={
        "EndpointConfigName": DEPLOYMENT_NAME,
        "ProductionVariants": [
            {
                "VariantName": "variant-name-1",
                "ModelName": DEPLOYMENT_NAME,
                "InitialInstanceCount": 1,
                "InstanceType": "{inputs.instance_type}",
            },
        ],
        "AsyncInferenceConfig": {"OutputConfig": {"S3OutputPath": S3_OUTPUT_PATH}},
    },
    endpoint_config={
        "EndpointName": DEPLOYMENT_NAME,
        "EndpointConfigName": DEPLOYMENT_NAME,
    },
    images={"primary_container_image": sagemaker_image},
    region=REGION,
    idempotence_token=True,  # set to True by default
)

This function returns an imperative workflow responsible for deploying the XGBoost model, creating an endpoint configuration and initializing an endpoint. Configurations relevant to these tasks are passed to the awssagemaker_inference.create_sagemaker_deployment function.

An idempotence token ensures the generation of unique tokens for each configuration, preventing name collisions during updates. By default, idempotence_token in create_sagemaker_deployment is set to True, causing the agent to append an idempotence token to the model name, endpoint config name, and endpoint.

  • If a field value isn’t provided (e.g., ModelName), the agent appends the idempotence token to the workflow name and uses that as the ModelName.

  • You can also manually set the idempotence token by adding {idempotence_token} to the relevant fields in the configuration, e.g., xgboost-{idempotence_token}.

sagemaker_image should include the inference code, necessary libraries, and an entrypoint for model serving.

Note

For more detailed instructions on using your custom inference image, refer to the Amazon SageMaker documentation.

If the plugin attempts to create a deployment that already exists, it will return the existing ARNs instead of raising an error.

Note

When two executions run in parallel and attempt to create the same endpoint, one execution will proceed with creating the endpoint while both will wait until the endpoint creation process is complete.

To receive inference requests, the container built with sagemaker_image must have a web server listening on port 8080 and must accept POST and GET requests to the /invocations and /ping endpoints, respectively.

We define the FastAPI inference code as follows:

from contextlib import asynccontextmanager
from datetime import datetime

import numpy as np
from fastapi import FastAPI, Request, Response, status

if sagemaker_image.is_container():
    from xgboost import Booster, DMatrix


class Predictor:
    def __init__(self, path: str, name: str):
        self._model = Booster()
        self._model.load_model(str(Path(path) / name))

    def predict(self, inputs: DMatrix) -> np.ndarray:
        return self._model.predict(inputs)


ml_model: Predictor = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global ml_model
    path = os.getenv("MODEL_PATH", "/opt/ml/model")
    ml_model = Predictor(path=path, name="xgboost_model")
    yield
    ml_model = None


app = FastAPI(lifespan=lifespan)


@app.get("/ping")
async def ping():
    return Response(content="OK", status_code=200)


@app.post("/invocations")
async def invocations(request: Request):
    print(f"Received request at {datetime.now()}")

    json_payload = await request.json()

    X_test = DMatrix(np.array(json_payload).reshape((1, -1)))
    y_test = ml_model.predict(X_test)

    response = Response(
        content=repr(round(y_test[0])).encode("utf-8"),
        status_code=status.HTTP_200_OK,
        media_type="text/plain",
    )
    return response

Create a file named serve to serve the model. In our case, we are using FastAPI:

!/bin/bash

_term() {
echo "Caught SIGTERM signal!"
kill -TERM "$child" 2>/dev/null
}

trap _term SIGTERM

echo "Starting the API server"
uvicorn sagemaker_inference_agent_example_usage:app --host 0.0.0.0 --port 8080&

child=$!
wait "$child"

You can trigger the sagemaker_deployment_wf by providing the model artifact path, execution role ARN, and instance type.

Once the endpoint creation status changes to InService, the SageMaker deployment workflow succeeds. You can then invoke the endpoint using the SageMaker agent as follows:

from flytekitplugins.awssagemaker_inference import SageMakerInvokeEndpointTask

invoke_endpoint = SageMakerInvokeEndpointTask(
    name="sagemaker_invoke_endpoint",
    config={
        "EndpointName": "YOUR_ENDPOINT_NAME_HERE",
        "InputLocation": "s3://sagemaker-agent-xgboost/inference_input",
    },
    region=REGION,
)

The awssagemaker_inference.SageMakerInvokeEndpointTask invokes an endpoint asynchronously, resulting in an S3 location that will be populated with the output after it’s generated. For instance, the inference_input file may include input like this: [6, 148, 72, 35, 0, 33.6, 0.627, 50]

To delete the deployment, you can instantiate a awssagemaker_inference.delete_sagemaker_deployment function.

from flytekitplugins.awssagemaker_inference import delete_sagemaker_deployment

sagemaker_deployment_deletion_wf = delete_sagemaker_deployment(name="sagemaker-deployment-deletion", region="us-east-2")


@workflow
def deployment_deletion_workflow():
    sagemaker_deployment_deletion_wf(
        endpoint_name="YOUR_ENDPOINT_NAME_HERE",
        endpoint_config_name="YOUR_ENDPOINT_CONFIG_NAME_HERE",
        model_name="YOUR_MODEL_NAME_HERE",
    )

You need to provide the endpoint name, endpoint config name, and the model name to execute this deletion, which removes the endpoint, endpoint config, and the model.

Available tasks#

You have the option to execute the SageMaker tasks independently. The following tasks are available for use:

All tasks except the awssagemaker_inference.SageMakerEndpointTask inherit the awssagemaker_inference.BotoTask. The awssagemaker_inference.BotoTask provides the flexibility to invoke any Boto3 method. If you need to interact with the Boto3 APIs, you can use this task.