diff --git a/mlflow/models/container/__init__.py b/mlflow/models/container/__init__.py index efaa16f0ba292..3e55d96eaddac 100644 --- a/mlflow/models/container/__init__.py +++ b/mlflow/models/container/__init__.py @@ -38,6 +38,8 @@ DISABLE_NGINX = "DISABLE_NGINX" ENABLE_MLSERVER = "ENABLE_MLSERVER" +SERVING_ENVIRONMENT = "SERVING_ENVIRONMENT" + def _init(cmd): """ diff --git a/mlflow/sagemaker/__init__.py b/mlflow/sagemaker/__init__.py index a813cd8360468..df443bf1d2d6f 100644 --- a/mlflow/sagemaker/__init__.py +++ b/mlflow/sagemaker/__init__.py @@ -22,7 +22,7 @@ from mlflow.utils.annotations import experimental from mlflow.utils.file_utils import TempDir from mlflow.models.container import SUPPORTED_FLAVORS as SUPPORTED_DEPLOYMENT_FLAVORS -from mlflow.models.container import DEPLOYMENT_CONFIG_KEY_FLAVOR_NAME +from mlflow.models.container import DEPLOYMENT_CONFIG_KEY_FLAVOR_NAME, SERVING_ENVIRONMENT DEFAULT_IMAGE_NAME = "mlflow-pyfunc" @@ -41,6 +41,8 @@ DEFAULT_SAGEMAKER_INSTANCE_TYPE = "ml.m4.xlarge" DEFAULT_SAGEMAKER_INSTANCE_COUNT = 1 +SAGEMAKER_SERVING_ENVIRONMENT = "SageMaker" + _logger = logging.getLogger(__name__) _full_template = "{account}.dkr.ecr.{region}.amazonaws.com/{image}:{version}" @@ -1208,7 +1210,10 @@ def _get_deployment_config(flavor_name): """ :return: The deployment configuration as a dictionary """ - deployment_config = {DEPLOYMENT_CONFIG_KEY_FLAVOR_NAME: flavor_name} + deployment_config = { + DEPLOYMENT_CONFIG_KEY_FLAVOR_NAME: flavor_name, + SERVING_ENVIRONMENT: SAGEMAKER_SERVING_ENVIRONMENT, + } return deployment_config diff --git a/tests/sagemaker/test_deployment.py b/tests/sagemaker/test_deployment.py index 28f91a75d8127..705bff2255a77 100644 --- a/tests/sagemaker/test_deployment.py +++ b/tests/sagemaker/test_deployment.py @@ -216,7 +216,7 @@ def test_attempting_to_deploy_in_asynchronous_mode_without_archiving_throws_exce @pytest.mark.large @mock_sagemaker_aws_services -def test_deploy_creates_sagemaker_and_s3_resources_with_expected_names_from_local( +def test_deploy_creates_sagemaker_and_s3_resources_with_expected_names_and_env_from_local( pretrained_model, sagemaker_client ): app_name = "test-app" @@ -245,11 +245,18 @@ def test_deploy_creates_sagemaker_and_s3_resources_with_expected_names_from_loca assert app_name in [ endpoint["EndpointName"] for endpoint in sagemaker_client.list_endpoints()["Endpoints"] ] + model_environment = sagemaker_client.describe_model(ModelName=model_name)["PrimaryContainer"][ + "Environment" + ] + assert model_environment == { + "MLFLOW_DEPLOYMENT_FLAVOR_NAME": "python_function", + "SERVING_ENVIRONMENT": "SageMaker", + } @pytest.mark.large @mock_sagemaker_aws_services -def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_names_from_local( +def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_names_and_env_from_local( pretrained_model, sagemaker_client ): app_name = "test-app" @@ -288,11 +295,18 @@ def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_names_from_ assert app_name in [ endpoint["EndpointName"] for endpoint in sagemaker_client.list_endpoints()["Endpoints"] ] + model_environment = sagemaker_client.describe_model(ModelName=model_name)["PrimaryContainer"][ + "Environment" + ] + assert model_environment == { + "MLFLOW_DEPLOYMENT_FLAVOR_NAME": "python_function", + "SERVING_ENVIRONMENT": "SageMaker", + } @pytest.mark.large @mock_sagemaker_aws_services -def test_deploy_creates_sagemaker_and_s3_resources_with_expected_names_from_s3( +def test_deploy_creates_sagemaker_and_s3_resources_with_expected_names_and_env_from_s3( pretrained_model, sagemaker_client ): local_model_path = _download_artifact_from_uri(pretrained_model.model_uri) @@ -328,11 +342,18 @@ def test_deploy_creates_sagemaker_and_s3_resources_with_expected_names_from_s3( assert app_name in [ endpoint["EndpointName"] for endpoint in sagemaker_client.list_endpoints()["Endpoints"] ] + model_environment = sagemaker_client.describe_model(ModelName=model_name)["PrimaryContainer"][ + "Environment" + ] + assert model_environment == { + "MLFLOW_DEPLOYMENT_FLAVOR_NAME": "python_function", + "SERVING_ENVIRONMENT": "SageMaker", + } @pytest.mark.large @mock_sagemaker_aws_services -def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_names_from_s3( +def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_names_and_env_from_s3( pretrained_model, sagemaker_client ): local_model_path = _download_artifact_from_uri(pretrained_model.model_uri) @@ -373,6 +394,13 @@ def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_names_from_ assert app_name in [ endpoint["EndpointName"] for endpoint in sagemaker_client.list_endpoints()["Endpoints"] ] + model_environment = sagemaker_client.describe_model(ModelName=model_name)["PrimaryContainer"][ + "Environment" + ] + assert model_environment == { + "MLFLOW_DEPLOYMENT_FLAVOR_NAME": "python_function", + "SERVING_ENVIRONMENT": "SageMaker", + } @pytest.mark.large