-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ebf8799
commit 4e445d7
Showing
8 changed files
with
139 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .models import sagemakermetrics_backends #noqa: F401 | ||
from .models import sagemakermetrics_backends # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1 @@ | ||
"""Exceptions raised by the sagemakermetrics service.""" | ||
from moto.core.exceptions import JsonRESTError | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,54 @@ | ||
"""SageMakerMetricsBackend class with methods for supported APIs.""" | ||
|
||
from datetime import datetime | ||
from typing import List, Dict, Union | ||
from typing import Dict, List, Union, cast | ||
|
||
from moto.core.base_backend import BaseBackend, BackendDict | ||
from moto.core.base_backend import BackendDict, BaseBackend | ||
from moto.sagemaker import sagemaker_backends | ||
from moto.sagemaker.models import METRIC_STEP_TYPE | ||
|
||
RESPONSE_TYPE = Dict[str, List[Dict[str, Union[str, int]]]] | ||
|
||
|
||
class SageMakerMetricsBackend(BaseBackend): | ||
"""Implementation of SageMakerMetrics APIs.""" | ||
|
||
def __init__(self, region_name, account_id): | ||
def __init__(self, region_name: str, account_id: str) -> None: | ||
super().__init__(region_name, account_id) | ||
self.sagemaker_backend = sagemaker_backends[account_id][region_name] | ||
|
||
def batch_put_metrics( | ||
self, | ||
trial_component_name: str, | ||
metric_data: List[Dict[str, Union[str, int, float, datetime]]], | ||
): | ||
self, | ||
trial_component_name: str, | ||
metric_data: List[Dict[str, Union[str, int, float, datetime]]], | ||
) -> RESPONSE_TYPE: | ||
return_response: RESPONSE_TYPE = {"Errors": []} | ||
|
||
if trial_component_name not in self.sagemaker_backend.trial_components: | ||
return { | ||
"Errors": [{'Code': 'VALIDATION_ERROR', 'MetricIndex': 0}] | ||
} | ||
return_response["Errors"].append( | ||
{"Code": "VALIDATION_ERROR", "MetricIndex": 0} | ||
) | ||
return return_response | ||
|
||
trial_component = self.sagemaker_backend.trial_components[trial_component_name] | ||
trial_component.metrics.extend(metric_data) | ||
return {} | ||
|
||
for metric in metric_data: | ||
metric_step: int = cast(int, metric["Step"]) | ||
metric_name: str = cast(str, metric["MetricName"]) | ||
if metric_name not in trial_component.metrics: | ||
metric_timestamp: int = cast(int, metric["Timestamp"]) | ||
values_dict: Dict[int, Dict[str, Union[str, int, float, datetime]]] = {} | ||
new_metric: Dict[str, Union[str, int, METRIC_STEP_TYPE]] = { | ||
"MetricName": metric_name, | ||
"Timestamp": metric_timestamp, | ||
"Values": values_dict, | ||
} | ||
trial_component.metrics[metric_name] = new_metric | ||
new_step: METRIC_STEP_TYPE = {metric_step: metric} | ||
trial_component_metric_values: METRIC_STEP_TYPE = cast( | ||
METRIC_STEP_TYPE, trial_component.metrics[metric_name]["Values"] | ||
) | ||
trial_component_metric_values.update(new_step) # type ignore | ||
return return_response | ||
|
||
|
||
sagemakermetrics_backends = BackendDict(SageMakerMetricsBackend, "sagemaker-metrics") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
"""sagemakermetrics base URL and path.""" | ||
|
||
from .responses import SageMakerMetricsResponse | ||
|
||
url_bases = [ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,72 @@ | ||
"""Unit tests for sagemakermetrics-supported APIs.""" | ||
from datetime import datetime | ||
|
||
import datetime | ||
|
||
import boto3 | ||
|
||
from moto import mock_aws | ||
|
||
# @mock_aws | ||
|
||
@mock_aws | ||
def test_batch_put_metrics(): | ||
trial_component_name = "some-trial-component-name" | ||
client_sagemaker = boto3.client( | ||
"sagemaker", | ||
aws_access_key_id=aws_access_key_id, | ||
aws_secret_access_key=aws_secret_access_key, | ||
aws_session_token=aws_session_token, | ||
region_name="eu-west-1" | ||
) | ||
# client_sagemaker.create_trial_component(TrialComponentName=trial_component_name) | ||
|
||
client = boto3.client( | ||
"sagemaker-metrics", | ||
aws_access_key_id=aws_access_key_id, | ||
aws_secret_access_key=aws_secret_access_key, | ||
aws_session_token=aws_session_token, | ||
region_name="eu-west-1" | ||
client_sagemaker = boto3.client("sagemaker", region_name="eu-west-1") | ||
client_sagemaker.create_trial_component(TrialComponentName=trial_component_name) | ||
describe_before_metrics = client_sagemaker.describe_trial_component( | ||
TrialComponentName=trial_component_name | ||
) | ||
|
||
client = boto3.client("sagemaker-metrics", region_name="eu-west-1") | ||
given_datetime = datetime.datetime(2024, 4, 21, 19, 33, 3) | ||
resp = client.batch_put_metrics( | ||
TrialComponentName=trial_component_name, | ||
MetricData=[{ | ||
'MetricName': 'some-metric-name', | ||
'Timestamp': datetime(2015, 1, 1), | ||
'Step': 123, | ||
'Value': 123.0 | ||
},] | ||
MetricData=[ | ||
{ | ||
"MetricName": "some-metric-name", | ||
"Timestamp": given_datetime, | ||
"Step": 0, | ||
"Value": 123.0, | ||
}, | ||
], | ||
) | ||
describe_after_metrics = client_sagemaker.describe_trial_component( | ||
TrialComponentName=trial_component_name | ||
) | ||
|
||
assert describe_before_metrics["Metrics"] == [] | ||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | ||
assert resp["Errors"] == [] | ||
assert describe_after_metrics["Metrics"][0]["MetricName"] == "some-metric-name" | ||
assert ( | ||
describe_after_metrics["Metrics"][0]["SourceArn"] | ||
== "arn:aws:sagemaker:eu-west-1:123456789012:experiment-trial-component/some-trial-component-name" | ||
) | ||
assert describe_after_metrics["Metrics"][0]["TimeStamp"] == datetime.datetime( | ||
Check failure on line 44 in tests/test_sagemakermetrics/test_sagemakermetrics.py GitHub Actions / test / test (3.8)
Check failure on line 44 in tests/test_sagemakermetrics/test_sagemakermetrics.py GitHub Actions / test / test (3.9)
Check failure on line 44 in tests/test_sagemakermetrics/test_sagemakermetrics.py GitHub Actions / test / test (3.10)
Check failure on line 44 in tests/test_sagemakermetrics/test_sagemakermetrics.py GitHub Actions / test / test (3.11)
Check failure on line 44 in tests/test_sagemakermetrics/test_sagemakermetrics.py GitHub Actions / test / test (3.12)
|
||
2024, 4, 21, 21, 33, 3 | ||
) | ||
assert describe_after_metrics["Metrics"][0]["Max"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["Min"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["Last"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["Count"] == 1 | ||
assert describe_after_metrics["Metrics"][0]["Avg"] == 123.0 | ||
assert describe_after_metrics["Metrics"][0]["StdDev"] == 0.0 | ||
|
||
|
||
@mock_aws | ||
def test_batch_put_metrics_should_return_validation_error_if_trial_component_not_found(): | ||
trial_component_name = "some-trial-component-name-not-existing" | ||
client = boto3.client( | ||
"sagemaker-metrics", | ||
region_name="eu-west-1" | ||
) | ||
client = boto3.client("sagemaker-metrics", region_name="eu-west-1") | ||
resp = client.batch_put_metrics( | ||
TrialComponentName=trial_component_name, | ||
MetricData=[{ | ||
'MetricName': 'some-metric-name', | ||
'Timestamp': datetime(2015, 1, 1), | ||
'Step': 0, | ||
'Value': 123.0, | ||
}] | ||
MetricData=[ | ||
{ | ||
"MetricName": "some-metric-name", | ||
"Timestamp": datetime.datetime(2015, 1, 1), | ||
"Step": 0, | ||
"Value": 123.0, | ||
} | ||
], | ||
) | ||
|
||
assert resp.get("Errors") is not None | ||
assert resp["Errors"][0]["Code"] == "VALIDATION_ERROR" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters