Skip to content

Commit

Permalink
Sagemaker-Metrics: Add sagemaker metrics (#7617)
Browse files Browse the repository at this point in the history
  • Loading branch information
YHallouard committed Apr 29, 2024
1 parent c2bcd13 commit a9866b0
Show file tree
Hide file tree
Showing 14 changed files with 262 additions and 2 deletions.
8 changes: 7 additions & 1 deletion IMPLEMENTATION_COVERAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7211,6 +7211,13 @@
- [ ] update_workteam
</details>

## sagemaker-metrics
<details>
<summary>100% implemented</summary>

- [X] batch_put_metrics
</details>

## sagemaker-runtime
<details>
<summary>66% implemented</summary>
Expand Down Expand Up @@ -8429,7 +8436,6 @@
- sagemaker-edge
- sagemaker-featurestore-runtime
- sagemaker-geospatial
- sagemaker-metrics
- savingsplans
- schemas
- securityhub
Expand Down
20 changes: 20 additions & 0 deletions docs/docs/services/sagemaker-metrics.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
.. _implementedservice_sagemaker-metrics:

.. |start-h3| raw:: html

<h3>

.. |end-h3| raw:: html

</h3>

=================
sagemaker-metrics
=================

.. autoclass:: moto.sagemakermetrics.models.SageMakerMetricsBackend

|start-h3| Implemented features for this service |end-h3|

- [X] batch_put_metrics

4 changes: 4 additions & 0 deletions moto/backend_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@
re.compile("https?://([0-9]+)\\.s3-control\\.(.+)\\.amazonaws\\.com"),
),
("sagemaker", re.compile("https?://api\\.sagemaker\\.(.+)\\.amazonaws.com")),
(
"sagemakermetrics",
re.compile("https?://metrics.sagemaker\\.(.+)\\.amazonaws\\.com"),
),
(
"sagemakerruntime",
re.compile("https?://runtime\\.sagemaker\\.(.+)\\.amazonaws\\.com"),
Expand Down
6 changes: 6 additions & 0 deletions moto/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
from moto.s3.models import S3Backend
from moto.s3control.models import S3ControlBackend
from moto.sagemaker.models import SageMakerModelBackend
from moto.sagemakermetrics.models import SageMakerMetricsBackend
from moto.sagemakerruntime.models import SageMakerRuntimeBackend
from moto.scheduler.models import EventBridgeSchedulerBackend
from moto.sdb.models import SimpleDBBackend
Expand Down Expand Up @@ -279,6 +280,7 @@ def get_service_from_url(url: str) -> Optional[str]:
"Literal['s3bucket_path']",
"Literal['s3control']",
"Literal['sagemaker']",
"Literal['sagemaker-metrics']",
"Literal['sagemaker-runtime']",
"Literal['scheduler']",
"Literal['sdb']",
Expand Down Expand Up @@ -614,6 +616,10 @@ def get_backend(
name: "Literal['sagemaker']",
) -> "BackendDict[SageMakerModelBackend]": ...
@overload
def get_backend(
name: "Literal['sagemaker-metrics']",
) -> "BackendDict[SageMakerMetricsBackend]": ...
@overload
def get_backend(
name: "Literal['sagemaker-runtime']",
) -> "BackendDict[SageMakerRuntimeBackend]": ...
Expand Down
2 changes: 2 additions & 0 deletions moto/moto_server/werkzeug_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def infer_service_region_host(
elif service == "sagemaker":
if environ["PATH_INFO"].endswith("invocations"):
host = f"runtime.{service}.{region}.amazonaws.com"
elif environ["PATH_INFO"].endswith("BatchPutMetrics"):
host = f"metrics.{service}.{region}.amazonaws.com"
else:
host = f"api.{service}.{region}.amazonaws.com"
elif service == "timestream":
Expand Down
45 changes: 44 additions & 1 deletion moto/sagemaker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@
},
}

METRIC_INFO_TYPE = Dict[str, Union[str, int, float, datetime]]
METRIC_STEP_TYPE = Dict[int, METRIC_INFO_TYPE]


class BaseObject(BaseModel):
def camelCase(self, key: str) -> str:
Expand Down Expand Up @@ -3755,16 +3758,56 @@ def __init__(
self.input_artifacts = input_artifacts if input_artifacts is not None else {}
self.output_artifacts = output_artifacts if output_artifacts is not None else {}
self.metadata_properties = metadata_properties
self.metrics: List[Dict[str, Union[float, datetime, str]]] = []
self.metrics: Dict[str, Dict[str, Union[str, int, METRIC_STEP_TYPE]]] = {}
self.sources: List[Dict[str, str]] = []

@property
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["Metrics"] = self.gen_metrics_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}

def gen_metrics_response_object(
self,
) -> List[Dict[str, Union[str, int, float, datetime]]]:
metrics_names = self.metrics.keys()
metrics_response_objects = []
for metrics_name in metrics_names:
metrics_steps: METRIC_STEP_TYPE = cast(
METRIC_STEP_TYPE, self.metrics[metrics_name]["Values"]
)
max_step = max(list(metrics_steps.keys()))
metrics_steps_values: List[float] = list(
map(
lambda metric: cast(float, metric["Value"]),
list(metrics_steps.values()),
)
)
count = len(metrics_steps_values)
mean = sum(metrics_steps_values) / count
std = (
sum(map(lambda value: (value - mean) ** 2, metrics_steps_values))
/ count
) ** 0.5
timestamp_int: int = cast(int, self.metrics[metrics_name]["Timestamp"])
metrics_response_object = {
"MetricName": metrics_name,
"SourceArn": self.trial_component_arn,
"TimeStamp": datetime.fromtimestamp(timestamp_int, tz=tzutc()).strftime(
"%Y-%m-%d %H:%M:%S"
),
"Max": max(metrics_steps_values),
"Min": min(metrics_steps_values),
"Last": metrics_steps[max_step]["Value"],
"Count": count,
"Avg": mean,
"StdDev": std,
}
metrics_response_objects.append(metrics_response_object)
return metrics_response_objects

@property
def response_create(self) -> Dict[str, str]:
return {"TrialComponentArn": self.trial_component_arn}
Expand Down
1 change: 1 addition & 0 deletions moto/sagemakermetrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .models import sagemakermetrics_backends # noqa: F401
1 change: 1 addition & 0 deletions moto/sagemakermetrics/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Exceptions raised by the sagemakermetrics service."""
54 changes: 54 additions & 0 deletions moto/sagemakermetrics/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""SageMakerMetricsBackend class with methods for supported APIs."""

from datetime import datetime
from typing import Dict, List, Union, cast

from moto.core.base_backend import BackendDict, BaseBackend
from moto.sagemaker import sagemaker_backends
from moto.sagemaker.models import METRIC_STEP_TYPE

RESPONSE_TYPE = Dict[str, List[Dict[str, Union[str, int]]]]


class SageMakerMetricsBackend(BaseBackend):
"""Implementation of SageMakerMetrics APIs."""

def __init__(self, region_name: str, account_id: str) -> None:
super().__init__(region_name, account_id)
self.sagemaker_backend = sagemaker_backends[account_id][region_name]

def batch_put_metrics(
self,
trial_component_name: str,
metric_data: List[Dict[str, Union[str, int, float, datetime]]],
) -> RESPONSE_TYPE:
return_response: RESPONSE_TYPE = {"Errors": []}

if trial_component_name not in self.sagemaker_backend.trial_components:
return_response["Errors"].append(
{"Code": "VALIDATION_ERROR", "MetricIndex": 0}
)
return return_response

trial_component = self.sagemaker_backend.trial_components[trial_component_name]
for metric in metric_data:
metric_step: int = cast(int, metric["Step"])
metric_name: str = cast(str, metric["MetricName"])
if metric_name not in trial_component.metrics:
metric_timestamp: int = cast(int, metric["Timestamp"])
values_dict: Dict[int, Dict[str, Union[str, int, float, datetime]]] = {}
new_metric: Dict[str, Union[str, int, METRIC_STEP_TYPE]] = {
"MetricName": metric_name,
"Timestamp": metric_timestamp,
"Values": values_dict,
}
trial_component.metrics[metric_name] = new_metric
new_step: METRIC_STEP_TYPE = {metric_step: metric}
trial_component_metric_values: METRIC_STEP_TYPE = cast(
METRIC_STEP_TYPE, trial_component.metrics[metric_name]["Values"]
)
trial_component_metric_values.update(new_step) # type ignore
return return_response


sagemakermetrics_backends = BackendDict(SageMakerMetricsBackend, "sagemaker-metrics")
28 changes: 28 additions & 0 deletions moto/sagemakermetrics/responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Handles incoming sagemakermetrics requests, invokes methods, returns responses."""

import json

from moto.core.responses import BaseResponse

from .models import SageMakerMetricsBackend, sagemakermetrics_backends


class SageMakerMetricsResponse(BaseResponse):
"""Handler for SageMakerMetrics requests and responses."""

def __init__(self) -> None:
super().__init__(service_name="sagemaker-metrics")

@property
def sagemakermetrics_backend(self) -> SageMakerMetricsBackend:
"""Return backend instance specific for this region."""
return sagemakermetrics_backends[self.current_account][self.region]

def batch_put_metrics(self) -> str:
trial_component_name = self._get_param("TrialComponentName")
metric_data = self._get_param("MetricData")
errors = self.sagemakermetrics_backend.batch_put_metrics(
trial_component_name=trial_component_name,
metric_data=metric_data,
)
return json.dumps(errors)
12 changes: 12 additions & 0 deletions moto/sagemakermetrics/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""sagemakermetrics base URL and path."""

from .responses import SageMakerMetricsResponse

url_bases = [
r"https?://metrics.sagemaker\.(.+)\.amazonaws\.com",
]

url_paths = {
"{0}/$": SageMakerMetricsResponse.dispatch,
"{0}/BatchPutMetrics$": SageMakerMetricsResponse.dispatch,
}
Empty file.
70 changes: 70 additions & 0 deletions tests/test_sagemakermetrics/test_sagemakermetrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Unit tests for sagemakermetrics-supported APIs."""

import datetime

import boto3

from moto import mock_aws


@mock_aws
def test_batch_put_metrics():
trial_component_name = "some-trial-component-name"
client_sagemaker = boto3.client("sagemaker", region_name="eu-west-1")
client_sagemaker.create_trial_component(TrialComponentName=trial_component_name)
describe_before_metrics = client_sagemaker.describe_trial_component(
TrialComponentName=trial_component_name
)

client = boto3.client("sagemaker-metrics", region_name="eu-west-1")
given_datetime = datetime.datetime(2024, 4, 21, 0, 0, 0)
resp = client.batch_put_metrics(
TrialComponentName=trial_component_name,
MetricData=[
{
"MetricName": "some-metric-name",
"Timestamp": given_datetime,
"Step": 0,
"Value": 123.0,
},
],
)
describe_after_metrics = client_sagemaker.describe_trial_component(
TrialComponentName=trial_component_name
)

assert describe_before_metrics["Metrics"] == []
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert resp["Errors"] == []
assert describe_after_metrics["Metrics"][0]["MetricName"] == "some-metric-name"
assert (
describe_after_metrics["Metrics"][0]["SourceArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:experiment-trial-component/some-trial-component-name"
)
assert describe_after_metrics["Metrics"][0]["TimeStamp"] == given_datetime
assert describe_after_metrics["Metrics"][0]["Max"] == 123.0
assert describe_after_metrics["Metrics"][0]["Min"] == 123.0
assert describe_after_metrics["Metrics"][0]["Last"] == 123.0
assert describe_after_metrics["Metrics"][0]["Count"] == 1
assert describe_after_metrics["Metrics"][0]["Avg"] == 123.0
assert describe_after_metrics["Metrics"][0]["StdDev"] == 0.0


@mock_aws
def test_batch_put_metrics_should_return_validation_error_if_trial_component_not_found():
trial_component_name = "some-trial-component-name-not-existing"
client = boto3.client("sagemaker-metrics", region_name="eu-west-1")
resp = client.batch_put_metrics(
TrialComponentName=trial_component_name,
MetricData=[
{
"MetricName": "some-metric-name",
"Timestamp": datetime.datetime(2015, 1, 1),
"Step": 0,
"Value": 123.0,
}
],
)

assert resp.get("Errors") is not None
assert resp["Errors"][0]["Code"] == "VALIDATION_ERROR"
13 changes: 13 additions & 0 deletions tests/test_sagemakermetrics/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Test different server responses."""

import moto.server as server


def test_sagemakermetrics_batch_put_metrics():
backend = server.create_backend_app("sagemaker-metrics")
test_client = backend.test_client()

resp = test_client.put("/BatchPutMetrics")

assert resp.status_code == 200
assert "VALIDATION_ERROR" in str(resp.data)

0 comments on commit a9866b0

Please sign in to comment.