-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sagemaker-Metrics: Add sagemaker metrics (#7617)
- Loading branch information
1 parent
c2bcd13
commit a9866b0
Showing
14 changed files
with
262 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
.. _implementedservice_sagemaker-metrics: | ||
|
||
.. |start-h3| raw:: html | ||
|
||
<h3> | ||
|
||
.. |end-h3| raw:: html | ||
|
||
</h3> | ||
|
||
================= | ||
sagemaker-metrics | ||
================= | ||
|
||
.. autoclass:: moto.sagemakermetrics.models.SageMakerMetricsBackend | ||
|
||
|start-h3| Implemented features for this service |end-h3| | ||
|
||
- [X] batch_put_metrics | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .models import sagemakermetrics_backends # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Exceptions raised by the sagemakermetrics service.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""SageMakerMetricsBackend class with methods for supported APIs.""" | ||
|
||
from datetime import datetime | ||
from typing import Dict, List, Union, cast | ||
|
||
from moto.core.base_backend import BackendDict, BaseBackend | ||
from moto.sagemaker import sagemaker_backends | ||
from moto.sagemaker.models import METRIC_STEP_TYPE | ||
|
||
RESPONSE_TYPE = Dict[str, List[Dict[str, Union[str, int]]]] | ||
|
||
|
||
class SageMakerMetricsBackend(BaseBackend): | ||
"""Implementation of SageMakerMetrics APIs.""" | ||
|
||
def __init__(self, region_name: str, account_id: str) -> None: | ||
super().__init__(region_name, account_id) | ||
self.sagemaker_backend = sagemaker_backends[account_id][region_name] | ||
|
||
def batch_put_metrics( | ||
self, | ||
trial_component_name: str, | ||
metric_data: List[Dict[str, Union[str, int, float, datetime]]], | ||
) -> RESPONSE_TYPE: | ||
return_response: RESPONSE_TYPE = {"Errors": []} | ||
|
||
if trial_component_name not in self.sagemaker_backend.trial_components: | ||
return_response["Errors"].append( | ||
{"Code": "VALIDATION_ERROR", "MetricIndex": 0} | ||
) | ||
return return_response | ||
|
||
trial_component = self.sagemaker_backend.trial_components[trial_component_name] | ||
for metric in metric_data: | ||
metric_step: int = cast(int, metric["Step"]) | ||
metric_name: str = cast(str, metric["MetricName"]) | ||
if metric_name not in trial_component.metrics: | ||
metric_timestamp: int = cast(int, metric["Timestamp"]) | ||
values_dict: Dict[int, Dict[str, Union[str, int, float, datetime]]] = {} | ||
new_metric: Dict[str, Union[str, int, METRIC_STEP_TYPE]] = { | ||
"MetricName": metric_name, | ||
"Timestamp": metric_timestamp, | ||
"Values": values_dict, | ||
} | ||
trial_component.metrics[metric_name] = new_metric | ||
new_step: METRIC_STEP_TYPE = {metric_step: metric} | ||
trial_component_metric_values: METRIC_STEP_TYPE = cast( | ||
METRIC_STEP_TYPE, trial_component.metrics[metric_name]["Values"] | ||
) | ||
trial_component_metric_values.update(new_step) # type ignore | ||
return return_response | ||
|
||
|
||
sagemakermetrics_backends = BackendDict(SageMakerMetricsBackend, "sagemaker-metrics") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""Handles incoming sagemakermetrics requests, invokes methods, returns responses.""" | ||
|
||
import json | ||
|
||
from moto.core.responses import BaseResponse | ||
|
||
from .models import SageMakerMetricsBackend, sagemakermetrics_backends | ||
|
||
|
||
class SageMakerMetricsResponse(BaseResponse): | ||
"""Handler for SageMakerMetrics requests and responses.""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__(service_name="sagemaker-metrics") | ||
|
||
@property | ||
def sagemakermetrics_backend(self) -> SageMakerMetricsBackend: | ||
"""Return backend instance specific for this region.""" | ||
return sagemakermetrics_backends[self.current_account][self.region] | ||
|
||
def batch_put_metrics(self) -> str: | ||
trial_component_name = self._get_param("TrialComponentName") | ||
metric_data = self._get_param("MetricData") | ||
errors = self.sagemakermetrics_backend.batch_put_metrics( | ||
trial_component_name=trial_component_name, | ||
metric_data=metric_data, | ||
) | ||
return json.dumps(errors) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""sagemakermetrics base URL and path.""" | ||
|
||
from .responses import SageMakerMetricsResponse | ||
|
||
url_bases = [ | ||
r"https?://metrics.sagemaker\.(.+)\.amazonaws\.com", | ||
] | ||
|
||
url_paths = { | ||
"{0}/$": SageMakerMetricsResponse.dispatch, | ||
"{0}/BatchPutMetrics$": SageMakerMetricsResponse.dispatch, | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Unit tests for sagemakermetrics-supported APIs.""" | ||
|
||
import datetime | ||
|
||
import boto3 | ||
|
||
from moto import mock_aws | ||
|
||
|
||
@mock_aws | ||
def test_batch_put_metrics(): | ||
trial_component_name = "some-trial-component-name" | ||
client_sagemaker = boto3.client("sagemaker", region_name="eu-west-1") | ||
client_sagemaker.create_trial_component(TrialComponentName=trial_component_name) | ||
describe_before_metrics = client_sagemaker.describe_trial_component( | ||
TrialComponentName=trial_component_name | ||
) | ||
|
||
client = boto3.client("sagemaker-metrics", region_name="eu-west-1") | ||
given_datetime = datetime.datetime(2024, 4, 21, 0, 0, 0) | ||
resp = client.batch_put_metrics( | ||
TrialComponentName=trial_component_name, | ||
MetricData=[ | ||
{ | ||
"MetricName": "some-metric-name", | ||
"Timestamp": given_datetime, | ||
"Step": 0, | ||
"Value": 123.0, | ||
}, | ||
], | ||
) | ||
describe_after_metrics = client_sagemaker.describe_trial_component( | ||
TrialComponentName=trial_component_name | ||
) | ||
|
||
assert describe_before_metrics["Metrics"] == [] | ||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | ||
assert resp["Errors"] == [] | ||
assert describe_after_metrics["Metrics"][0]["MetricName"] == "some-metric-name" | ||
assert ( | ||
describe_after_metrics["Metrics"][0]["SourceArn"] | ||
== "arn:aws:sagemaker:eu-west-1:123456789012:experiment-trial-component/some-trial-component-name" | ||
) | ||
assert describe_after_metrics["Metrics"][0]["TimeStamp"] == given_datetime | ||
assert describe_after_metrics["Metrics"][0]["Max"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["Min"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["Last"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["Count"] == 1 | ||
assert describe_after_metrics["Metrics"][0]["Avg"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["StdDev"] == 0.0 | ||
|
||
|
||
@mock_aws | ||
def test_batch_put_metrics_should_return_validation_error_if_trial_component_not_found(): | ||
trial_component_name = "some-trial-component-name-not-existing" | ||
client = boto3.client("sagemaker-metrics", region_name="eu-west-1") | ||
resp = client.batch_put_metrics( | ||
TrialComponentName=trial_component_name, | ||
MetricData=[ | ||
{ | ||
"MetricName": "some-metric-name", | ||
"Timestamp": datetime.datetime(2015, 1, 1), | ||
"Step": 0, | ||
"Value": 123.0, | ||
} | ||
], | ||
) | ||
|
||
assert resp.get("Errors") is not None | ||
assert resp["Errors"][0]["Code"] == "VALIDATION_ERROR" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Test different server responses.""" | ||
|
||
import moto.server as server | ||
|
||
|
||
def test_sagemakermetrics_batch_put_metrics(): | ||
backend = server.create_backend_app("sagemaker-metrics") | ||
test_client = backend.test_client() | ||
|
||
resp = test_client.put("/BatchPutMetrics") | ||
|
||
assert resp.status_code == 200 | ||
assert "VALIDATION_ERROR" in str(resp.data) |