AWS SageMaker agent example

Deploy and serve an XGBoost model on AWS SageMaker using FastAPI

This example demonstrates how to deploy and serve an XGBoost model on SageMaker using a custom FastAPI inference server.

We train an XGBoost model on the Pima Indians Diabetes dataset and package it as a tar.gz archive to be stored in an S3 bucket.

The model artifact must be stored in an S3 bucket so that SageMaker can access it.

import os
import tarfile

import flytekit
from flytekit import ImageSpec, task, workflow
from flytekit.types.file import FlyteFile
from numpy import loadtxt
from sklearn.model_selection import train_test_split

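train_model_image = ImageSpec(
    registry="",  # Replace with a container registry to which you can publish.
    packages=["xgboost", "scikit-learn"],
)

if train_model_image.is_container():
    from xgboost import XGBClassifier

@task(container_image=train_model_image)
def train_model(dataset: FlyteFile) -> FlyteFile:
    # The first 8 columns are features; the 9th column is the binary label.
    dataset = loadtxt(dataset.download(), delimiter=",")
    X = dataset[:, 0:8]
    Y = dataset[:, 8]
    X_train, _, y_train, _ = train_test_split(X, Y, test_size=0.33, random_state=7)

    model = XGBClassifier()
    model.fit(X_train, y_train)

    # Save the underlying Booster in XGBoost's JSON format.
    serialized_model = os.path.join(flytekit.current_context().working_directory, "xgboost_model.json")
    booster = model.get_booster()
    booster.save_model(serialized_model)

    return FlyteFile(path=serialized_model)

@task
def convert_to_tar(model: FlyteFile) -> FlyteFile:
    # Store the model at the archive root as "xgboost_model", the name the inference code loads.
    with tarfile.open("model.tar.gz", "w:gz") as tf:
        tf.add(model.download(), arcname="xgboost_model")

    return FlyteFile("model.tar.gz")

@workflow
def sagemaker_xgboost_wf(
    dataset: FlyteFile = "",  # URL or path of the Pima Indians Diabetes CSV
) -> FlyteFile:
    serialized_model = train_model(dataset=dataset)
    return convert_to_tar(model=serialized_model)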


Set registry to a container registry to which you can publish. To push the image to the local registry in the demo cluster, set the registry to localhost:30000.

The workflow above produces a compressed model artifact. Upload it to an S3 bucket and note its S3 URI; you pass it as the model_path input when deploying.
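SageMaker extracts the archive referenced by ModelDataUrl into /opt/ml/model, so the arcname used in convert_to_tar determines the file name the inference server later loads. The local sketch below uses a dummy file in place of the trained model (the helper archive_members is hypothetical) and confirms that the archive holds a single top-level entry named xgboost_model:

```python
import os
import tarfile
import tempfile


def archive_members(archive_path: str) -> list[str]:
    """Return the member names stored in a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tf:
        return tf.getnames()


with tempfile.TemporaryDirectory() as workdir:
    # Stand-in for the serialized booster produced by train_model.
    model_file = os.path.join(workdir, "xgboost_model.json")
    with open(model_file, "w") as f:
        f.write("{}")

    # Same layout as convert_to_tar: the file sits at the archive root as "xgboost_model".
    archive_path = os.path.join(workdir, "model.tar.gz")
    with tarfile.open(archive_path, "w:gz") as tf:
        tf.add(model_file, arcname="xgboost_model")

    members = archive_members(archive_path)
    print(members)  # ['xgboost_model']
```

After extraction, the file appears as /opt/ml/model/xgboost_model, which is the path the FastAPI server builds from its default MODEL_PATH.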

To deploy the model on SageMaker, use the awssagemaker_inference.create_sagemaker_deployment function.

from flytekit import kwtypes
from flytekitplugins.awssagemaker_inference import create_sagemaker_deployment

REGION = "us-east-2"
S3_OUTPUT_PATH = "s3://sagemaker-agent-xgboost/inference-output/output"
NEW_DEPLOYMENT_NAME = "xgboost-fastapi-{idempotence_token}"
EXISTING_DEPLOYMENT_NAME = "xgboost-fastapi-{inputs.idempotence_token}"

sagemaker_image = ImageSpec(
    registry="",  # Amazon EC2 Container Registry or a Docker registry accessible from your VPC.
    packages=["xgboost", "fastapi", "uvicorn", "scikit-learn"],
).with_commands(["chmod +x /root/serve"])

sagemaker_deployment_wf = create_sagemaker_deployment(
    name="xgboost-fastapi",
    model_input_types=kwtypes(model_path=str, execution_role_arn=str),
    model_config={  # Creates the SageMaker model
        "ModelName": NEW_DEPLOYMENT_NAME,
        "PrimaryContainer": {
            "Image": "{images.primary_container_image}",
            "ModelDataUrl": "{inputs.model_path}",
        },
        "ExecutionRoleArn": "{inputs.execution_role_arn}",
    },
    endpoint_config_input_types=kwtypes(instance_type=str),
    endpoint_config_config={  # Creates the endpoint configuration
        "EndpointConfigName": NEW_DEPLOYMENT_NAME,
        "ProductionVariants": [
            {
                "VariantName": "variant-name-1",
                "ModelName": EXISTING_DEPLOYMENT_NAME,  # Model created in the previous step
                "InitialInstanceCount": 1,
                "InstanceType": "{inputs.instance_type}",
            },
        ],
        "AsyncInferenceConfig": {"OutputConfig": {"S3OutputPath": S3_OUTPUT_PATH}},
    },
    endpoint_config={  # Creates the endpoint
        "EndpointName": NEW_DEPLOYMENT_NAME,
        "EndpointConfigName": EXISTING_DEPLOYMENT_NAME,  # Endpoint config from the previous step
    },
    images={"primary_container_image": sagemaker_image},
    region=REGION,
)

This function returns an imperative workflow that deploys the XGBoost model, creates an endpoint configuration, and then creates the endpoint. The configuration for each of these steps is passed to the awssagemaker_inference.create_sagemaker_deployment function.
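The configuration dictionaries contain placeholders such as {inputs.model_path} and {images.primary_container_image} that are filled in at run time. The snippet below is a rough illustration of that idea using Python's str.format attribute lookup; it is an assumption for explanatory purposes, not the plugin's actual resolution logic, and resolve_placeholders as well as all sample values are hypothetical.

```python
from types import SimpleNamespace
from typing import Any


def resolve_placeholders(value: Any, **namespaces: Any) -> Any:
    """Recursively apply str.format to every string in a nested dict/list structure."""
    if isinstance(value, str):
        return value.format(**namespaces)
    if isinstance(value, dict):
        return {k: resolve_placeholders(v, **namespaces) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_placeholders(v, **namespaces) for v in value]
    return value


model_config = {
    "ModelName": "xgboost-fastapi-{idempotence_token}",
    "PrimaryContainer": {
        "Image": "{images.primary_container_image}",
        "ModelDataUrl": "{inputs.model_path}",
    },
    "ExecutionRoleArn": "{inputs.execution_role_arn}",
}

resolved = resolve_placeholders(
    model_config,
    # Hypothetical run-time values.
    inputs=SimpleNamespace(
        model_path="s3://my-bucket/model.tar.gz",
        execution_role_arn="arn:aws:iam::123456789012:role/MySageMakerRole",
    ),
    images=SimpleNamespace(primary_container_image="registry.example.com/sagemaker-xgboost:latest"),
    idempotence_token="abc123",
)
print(resolved["ModelName"])  # xgboost-fastapi-abc123
```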

The idempotence token makes the generated resource names unique for each configuration, which prevents name collisions when a deployment is updated.

  • idempotence_token is a hash of the current task's configuration.

  • inputs.idempotence_token is the idempotence token of the previous task. The workflow passes the previous task's token to the current task as an input.
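To make the chaining concrete, here is a minimal sketch that assumes the token is a truncated SHA-256 digest of each step's configuration. The hashing scheme, token length, and the helper config_token are assumptions; the plugin's actual scheme may differ. Each step formats its own token into NEW_DEPLOYMENT_NAME and receives the previous step's token through inputs.idempotence_token, which fills EXISTING_DEPLOYMENT_NAME:

```python
import hashlib
import json
from types import SimpleNamespace

NEW_DEPLOYMENT_NAME = "xgboost-fastapi-{idempotence_token}"
EXISTING_DEPLOYMENT_NAME = "xgboost-fastapi-{inputs.idempotence_token}"


def config_token(config: dict, length: int = 10) -> str:
    """Hypothetical token: a short SHA-256 digest of the canonical JSON config."""
    canonical = json.dumps(config, sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]


model_cfg = {"ModelDataUrl": "s3://my-bucket/model.tar.gz"}
endpoint_config_cfg = {"InstanceType": "ml.m5.large"}

# Step 1: the model name uses the model step's own token.
model_token = config_token(model_cfg)
model_name = NEW_DEPLOYMENT_NAME.format(idempotence_token=model_token)

# Step 2: the endpoint config uses its own token for its name and refers to
# the model through the token injected from the previous step.
endpoint_config_token = config_token(endpoint_config_cfg)
previous = SimpleNamespace(idempotence_token=model_token)
endpoint_config_name = NEW_DEPLOYMENT_NAME.format(idempotence_token=endpoint_config_token)
referenced_model = EXISTING_DEPLOYMENT_NAME.format(inputs=previous)

print(model_name, endpoint_config_name, referenced_model == model_name)
```

Because the token depends only on the configuration, rerunning with an unchanged configuration reproduces the same names, while any configuration change yields a new name instead of colliding with the old resource.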

sagemaker_image should include the inference code, the necessary libraries, and an entrypoint for model serving. Here, the entrypoint is the serve script at /root/serve, which the with_commands call makes executable.

For more detailed instructions on using your custom inference image, refer to the Amazon SageMaker documentation.

If the plugin attempts to create a deployment that already exists, it will return the existing ARNs instead of raising an error.
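This is a get-or-create pattern. The sketch below illustrates the pattern generically; AlreadyExistsError, FakeSageMaker, get_or_create, and the ARN format are hypothetical stand-ins, and the plugin's real error handling is not reproduced here.

```python
class AlreadyExistsError(Exception):
    """Hypothetical error raised when a resource name is already taken."""


class FakeSageMaker:
    """In-memory stand-in for the SageMaker API, keyed by resource name."""

    def __init__(self) -> None:
        self.arns: dict[str, str] = {}

    def create(self, name: str) -> str:
        if name in self.arns:
            raise AlreadyExistsError(name)
        arn = f"arn:aws:sagemaker:us-east-2:123456789012:endpoint/{name}"
        self.arns[name] = arn
        return arn

    def describe(self, name: str) -> str:
        return self.arns[name]


def get_or_create(client: FakeSageMaker, name: str) -> str:
    """Create the resource, or return the existing ARN if the name is taken."""
    try:
        return client.create(name)
    except AlreadyExistsError:
        return client.describe(name)


client = FakeSageMaker()
first = get_or_create(client, "xgboost-fastapi-abc123")
second = get_or_create(client, "xgboost-fastapi-abc123")
print(first == second)  # True
```

Combined with configuration-derived names, this makes rerunning an unchanged deployment workflow harmless.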

To receive inference requests, the container built with sagemaker_image must have a web server listening on port 8080 and must accept POST and GET requests to the /invocations and /ping endpoints, respectively.
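To illustrate that contract without FastAPI, the following standard-library-only sketch serves GET /ping and POST /invocations. The dummy predict function and the ContractHandler class are hypothetical, and the server binds an ephemeral port rather than 8080 so it can run locally next to other services:

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer


def predict(features: list[float]) -> float:
    """Hypothetical stand-in for the XGBoost model."""
    return float(len(features))


class ContractHandler(BaseHTTPRequestHandler):
    def do_GET(self) -> None:
        # Health check: answer GET /ping with 200.
        if self.path == "/ping":
            self._send(200, b"OK")
        else:
            self._send(404, b"")

    def do_POST(self) -> None:
        # Inference: read a JSON array of features and return a prediction.
        if self.path != "/invocations":
            self._send(404, b"")
            return
        length = int(self.headers.get("Content-Length", 0))
        features = json.loads(self.rfile.read(length))
        self._send(200, json.dumps(predict(features)).encode("utf-8"))

    def _send(self, status: int, body: bytes) -> None:
        self.send_response(status)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format: str, *args) -> None:
        pass  # Keep the example's output quiet.


# Port 0 lets the OS pick a free port; a real container listens on 8080.
server = HTTPServer(("127.0.0.1", 0), ContractHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()
base = f"http://127.0.0.1:{server.server_address[1]}"

# Bypass any HTTP proxy settings so requests go straight to the local server.
opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))

with opener.open(f"{base}/ping") as resp:
    ping_body = resp.read()

request = urllib.request.Request(
    f"{base}/invocations",
    data=json.dumps([6, 148, 72, 35, 0, 33.6, 0.627, 50]).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with opener.open(request) as resp:
    prediction = json.loads(resp.read())

server.shutdown()
server.server_close()
print(ping_body, prediction)  # b'OK' 8.0
```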

We define the FastAPI inference code as follows:

from contextlib import asynccontextmanager
from datetime import datetime

import numpy as np
from fastapi import FastAPI, Request, Response, status

if sagemaker_image.is_container():
    from xgboost import Booster, DMatrix

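class Predictor:
    def __init__(self, path: str, name: str):
        self._model = Booster()
        self._model.load_model(os.path.join(path, name))

    # String annotation so the module still imports where xgboost is not installed.
    def predict(self, inputs: "DMatrix") -> np.ndarray:
        return self._model.predict(inputs)

ml_model: Predictor = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup; SageMaker extracts the model archive into /opt/ml/model.
    global ml_model
    path = os.getenv("MODEL_PATH", "/opt/ml/model")
    ml_model = Predictor(path=path, name="xgboost_model")
    yield
    # Release the model at shutdown.
    ml_model = None

app = FastAPI(lifespan=lifespan)

@app.get("/ping")
async def ping():
    # Health check used by SageMaker.
    return Response(content="OK", status_code=200)

@app.post("/invocations")
async def invocations(request: Request):
    print(f"Received request at {datetime.now()}")

    # Expect a JSON array holding one row of features.
    json_payload = await request.json()

    X_test = DMatrix(np.array(json_payload).reshape((1, -1)))
    y_test = ml_model.predict(X_test)

    # Return the raw bytes of the float32 prediction array.
    response = Response(
        content=y_test.tobytes(),
        status_code=status.HTTP_200_OK,
        media_type="text/plain",
    )
    return response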
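If the handler returns the raw bytes of the prediction array (y_test.tobytes()) rather than JSON, a client has to know the dtype to decode the body. XGBoost's Booster.predict returns float32 values, so the round trip can be sketched as follows, with a made-up probability standing in for a real prediction:

```python
import numpy as np

# The payload from a client: one row of eight Pima Indians Diabetes features.
payload = [6, 148, 72, 35, 0, 33.6, 0.627, 50]
row = np.array(payload).reshape((1, -1))  # the same reshape the handler applies
print(row.shape)  # (1, 8)

# Made-up prediction standing in for Booster.predict output (float32).
y_test = np.array([0.73], dtype=np.float32)
body = y_test.tobytes()  # what the /invocations handler sends back

# Client side: decode the bytes with the same dtype.
decoded = np.frombuffer(body, dtype=np.float32)
print(decoded.tolist() == y_test.tolist())  # True
```

Returning JSON instead (for example json.dumps(y_test.tolist())) would avoid this implicit dtype contract at the cost of a slightly larger response.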

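Create an executable file named serve that starts the FastAPI app with uvicorn:

#!/bin/bash

# Forward SIGTERM to the uvicorn process so it can shut down cleanly.
_term() {
echo "Caught SIGTERM signal!"
kill -TERM "$child" 2>/dev/null
}

trap _term SIGTERM

echo "Starting the API server"
# Listen on all interfaces on port 8080.
uvicorn sagemaker_inference_agent_example_usage:app --host 0.0.0.0 --port 8080&

# Record the PID of the background server and wait on it.
child=$!
wait "$child"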

Trigger sagemaker_deployment_wf by providing the model artifact path (model_path), the execution role ARN (execution_role_arn), and the instance type (instance_type).
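Before launching, it can help to sanity-check these three inputs. The helper below is hypothetical and checks only basic shapes: an s3:// URI pointing at a .tar.gz archive (as produced in this example), an IAM role ARN, and a SageMaker instance type, which starts with ml.:

```python
import re

# Matches partitions such as aws, aws-cn, and aws-us-gov.
ROLE_ARN = re.compile(r"^arn:aws[a-z-]*:iam::\d{12}:role/.+$")


def validate_deployment_inputs(model_path: str, execution_role_arn: str, instance_type: str) -> list[str]:
    """Return a list of problems found in the deployment inputs (empty if none)."""
    problems = []
    if not (model_path.startswith("s3://") and model_path.endswith(".tar.gz")):
        problems.append("model_path should be an s3:// URI to a .tar.gz archive")
    if not ROLE_ARN.match(execution_role_arn):
        problems.append("execution_role_arn should look like arn:aws:iam::<account-id>:role/<name>")
    if not instance_type.startswith("ml."):
        problems.append("instance_type should be a SageMaker type such as ml.m5.large")
    return problems


print(validate_deployment_inputs(
    "s3://my-bucket/model.tar.gz",
    "arn:aws:iam::123456789012:role/MySageMakerRole",
    "ml.m5.large",
))  # []
```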

Once the endpoint creation status changes to InService, the SageMaker deployment workflow succeeds. You can then invoke the endpoint using the SageMaker agent as follows:

from flytekitplugins.awssagemaker_inference import SageMakerInvokeEndpointTask

invoke_endpoint = SageMakerInvokeEndpointTask(
    name="sagemaker_invoke_endpoint",
    config={
        "EndpointName": "YOUR_ENDPOINT_NAME_HERE",
        "InputLocation": "s3://sagemaker-agent-xgboost/inference_input",
    },
    region=REGION,
)

The awssagemaker_inference.SageMakerInvokeEndpointTask invokes the endpoint asynchronously and returns an S3 location that is populated with the output once it has been generated. For instance, the inference_input file may contain a single row of eight feature values: [6, 148, 72, 35, 0, 33.6, 0.627, 50]
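Since the result appears in S3 only after the endpoint has processed the request, callers typically poll the output location. A minimal polling sketch, assuming a caller-supplied exists callable (for example, one wrapping an S3 head_object call); wait_for_output and the fake checker are hypothetical:

```python
import time
from typing import Callable


def wait_for_output(exists: Callable[[], bool], timeout: float = 60.0, interval: float = 1.0) -> bool:
    """Poll `exists` until it returns True or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while True:
        if exists():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


# Fake checker: the output "appears" on the third poll.
calls = {"n": 0}


def fake_exists() -> bool:
    calls["n"] += 1
    return calls["n"] >= 3


print(wait_for_output(fake_exists, timeout=5.0, interval=0.01), calls["n"])  # True 3
```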

To delete the deployment, call the awssagemaker_inference.delete_sagemaker_deployment function, which returns a deletion workflow.

from flytekitplugins.awssagemaker_inference import delete_sagemaker_deployment

sagemaker_deployment_deletion_wf = delete_sagemaker_deployment(name="sagemaker-deployment-deletion", region="us-east-2")

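@workflow
def deployment_deletion_workflow():
    sagemaker_deployment_deletion_wf(
        endpoint_name="YOUR_ENDPOINT_NAME", endpoint_config_name="YOUR_ENDPOINT_CONFIG_NAME", model_name="YOUR_MODEL_NAME"
    )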

To run the deletion, provide the endpoint name, endpoint configuration name, and model name; the workflow removes the endpoint, the endpoint configuration, and the model.
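Deletion mirrors creation in reverse. As a rough sketch of the equivalent direct Boto3 calls (delete_endpoint, delete_endpoint_config, and delete_model are SageMaker client methods; the function and the recording client are hypothetical, and this is not necessarily how the plugin orders its calls):

```python
from typing import Any


def delete_deployment(client: Any, endpoint_name: str, endpoint_config_name: str, model_name: str) -> None:
    """Delete a SageMaker deployment in reverse order of creation."""
    client.delete_endpoint(EndpointName=endpoint_name)
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    client.delete_model(ModelName=model_name)


class RecordingClient:
    """Stand-in for boto3.client("sagemaker") that records calls instead of making them."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, dict]] = []

    def __getattr__(self, method: str):
        def record(**kwargs):
            self.calls.append((method, kwargs))
        return record


client = RecordingClient()
delete_deployment(client, "my-endpoint", "my-endpoint-config", "my-model")
print([name for name, _ in client.calls])
# ['delete_endpoint', 'delete_endpoint_config', 'delete_model']
```

With real credentials, the same function would accept boto3.client("sagemaker", region_name="us-east-2").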

Available tasks

You have the option to execute the SageMaker tasks independently. The following tasks are available for use:

All tasks except awssagemaker_inference.SageMakerEndpointTask inherit from awssagemaker_inference.BotoTask, which can invoke any Boto3 method. If you need to interact with other Boto3 APIs, you can use this task directly.
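The idea behind a generic Boto3 task can be sketched as looking up a method name on a client and calling it with a keyword config. This is an illustration only: call_boto_method and the fake client are hypothetical, and the plugin's BotoTask has its own interface.

```python
from typing import Any


def call_boto_method(client: Any, method: str, config: dict[str, Any]) -> Any:
    """Look up `method` on a Boto3-style client and call it with `config` as keyword arguments."""
    operation = getattr(client, method, None)
    if operation is None or not callable(operation):
        raise AttributeError(f"client has no method {method!r}")
    return operation(**config)


class FakeSageMakerClient:
    """Stand-in for boto3.client("sagemaker") with one read-only operation."""

    def describe_endpoint(self, EndpointName: str) -> dict[str, str]:
        return {"EndpointName": EndpointName, "EndpointStatus": "InService"}


result = call_boto_method(FakeSageMakerClient(), "describe_endpoint", {"EndpointName": "my-endpoint"})
print(result["EndpointStatus"])  # InService
```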