Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
YHallouard committed Apr 21, 2024
1 parent 45f13c2 commit e614c1c
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 51 deletions.
38 changes: 37 additions & 1 deletion moto/sagemaker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3755,16 +3755,52 @@ def __init__(
self.input_artifacts = input_artifacts if input_artifacts is not None else {}
self.output_artifacts = output_artifacts if output_artifacts is not None else {}
self.metadata_properties = metadata_properties
self.metrics: List[Dict[str, Union[float, datetime, str]]] = []
# self.metrics: List[Dict[str, Union[float, datetime, str]]] = []
self.metrics: Dict[str, Dict[str, Dict[str, Union[float, datetime, str]]]] = {}
self.sources: List[Dict[str, str]] = []

@property
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["Metrics"] = self.gen_metrics_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}

def gen_metrics_response_object(
self,
) -> List[Dict[str, Union[str, int, float, datetime]]]:
metrics_names = self.metrics.keys()
metrics_response_objects = []
for metrics_name in metrics_names:
max_step = max(list(self.metrics[metrics_name]["Values"].keys()))
metrics_values = list(
map(
lambda metric: metric["Value"],
list(self.metrics[metrics_name]["Values"].values()),
)
)
count = len(self.metrics[metrics_name]["Values"])
mean = sum(metrics_values) / count
std = (
sum(map(lambda value: (value - mean) ** 2, metrics_values)) / count
) ** 0.5
metrics_response_object = {
"MetricName": metrics_name,
"SourceArn": self.trial_component_arn,
"TimeStamp": datetime.fromtimestamp(
self.metrics[metrics_name]["Timestamp"]
).strftime("%Y-%m-%d %H:%M:%S"),
"Max": max(metrics_values),
"Min": min(metrics_values),
"Last": self.metrics[metrics_name]["Values"][max_step]["Value"],
"Count": count,
"Avg": mean,
"StdDev": std,
}
metrics_response_objects.append(metrics_response_object)
return metrics_response_objects

@property
def response_create(self) -> Dict[str, str]:
return {"TrialComponentArn": self.trial_component_arn}
Expand Down
2 changes: 1 addition & 1 deletion moto/sagemakermetrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .models import sagemakermetrics_backends #noqa: F401
from .models import sagemakermetrics_backends # noqa: F401
2 changes: 0 additions & 2 deletions moto/sagemakermetrics/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
"""Exceptions raised by the sagemakermetrics service."""
from moto.core.exceptions import JsonRESTError

37 changes: 26 additions & 11 deletions moto/sagemakermetrics/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""SageMakerMetricsBackend class with methods for supported APIs."""

from datetime import datetime
from typing import List, Dict, Union
from typing import Dict, List, Union

from moto.core.base_backend import BaseBackend, BackendDict
from moto.core.base_backend import BackendDict, BaseBackend
from moto.sagemaker import sagemaker_backends


Expand All @@ -14,17 +15,31 @@ def __init__(self, region_name, account_id):
self.sagemaker_backend = sagemaker_backends[account_id][region_name]

def batch_put_metrics(
self,
trial_component_name: str,
metric_data: List[Dict[str, Union[str, int, float, datetime]]],
self,
trial_component_name: str,
metric_data: List[Dict[str, Union[str, int, float, datetime]]],
):
return_response = {"Errors": []}

if trial_component_name not in self.sagemaker_backend.trial_components:
return {
"Errors": [{'Code': 'VALIDATION_ERROR', 'MetricIndex': 0}]
}
return_response["Errors"].append(
{"Code": "VALIDATION_ERROR", "MetricIndex": 0}
)
return return_response

trial_component = self.sagemaker_backend.trial_components[trial_component_name]
trial_component.metrics.extend(metric_data)
return {}

for metric in metric_data:
if metric["MetricName"] not in trial_component.metrics:
trial_component.metrics[metric["MetricName"]] = {
"MetricName": metric["MetricName"],
"Timestamp": metric["Timestamp"],
"Values": {},
}
trial_component.metrics[metric["MetricName"]]["Values"][metric["Step"]] = (
metric
)

return return_response


sagemakermetrics_backends = BackendDict(SageMakerMetricsBackend, "sagemaker-metrics")
4 changes: 3 additions & 1 deletion moto/sagemakermetrics/responses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Handles incoming sagemakermetrics requests, invokes methods, returns responses."""

import json

from moto.core.responses import BaseResponse

from .models import sagemakermetrics_backends


Expand All @@ -15,7 +17,7 @@ def __init__(self):
def sagemakermetrics_backend(self):
"""Return backend instance specific for this region."""
return sagemakermetrics_backends[self.current_account][self.region]

def batch_put_metrics(self):
trial_component_name = self._get_param("TrialComponentName")
metric_data = self._get_param("MetricData")
Expand Down
1 change: 1 addition & 0 deletions moto/sagemakermetrics/urls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""sagemakermetrics base URL and path."""

from .responses import SageMakerMetricsResponse

url_bases = [
Expand Down
82 changes: 48 additions & 34 deletions tests/test_sagemakermetrics/test_sagemakermetrics.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,72 @@
"""Unit tests for sagemakermetrics-supported APIs."""
from datetime import datetime

import datetime

import boto3

from moto import mock_aws

# @mock_aws

@mock_aws
def test_batch_put_metrics():
trial_component_name = "some-trial-component-name"
client_sagemaker = boto3.client(
"sagemaker",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name="eu-west-1"
)
# client_sagemaker.create_trial_component(TrialComponentName=trial_component_name)

client = boto3.client(
"sagemaker-metrics",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name="eu-west-1"
client_sagemaker = boto3.client("sagemaker", region_name="eu-west-1")
client_sagemaker.create_trial_component(TrialComponentName=trial_component_name)
describe_before_metrics = client_sagemaker.describe_trial_component(
TrialComponentName=trial_component_name
)

client = boto3.client("sagemaker-metrics", region_name="eu-west-1")
given_datetime = datetime.datetime(2024, 4, 21, 19, 33, 3)
resp = client.batch_put_metrics(
TrialComponentName=trial_component_name,
MetricData=[{
'MetricName': 'some-metric-name',
'Timestamp': datetime(2015, 1, 1),
'Step': 123,
'Value': 123.0
},]
MetricData=[
{
"MetricName": "some-metric-name",
"Timestamp": given_datetime,
"Step": 0,
"Value": 123.0,
},
],
)
describe_after_metrics = client_sagemaker.describe_trial_component(
TrialComponentName=trial_component_name
)

assert describe_before_metrics["Metrics"] == []
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert resp["Errors"] == []
assert describe_after_metrics["Metrics"][0]["MetricName"] == "some-metric-name"
assert (
describe_after_metrics["Metrics"][0]["SourceArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:experiment-trial-component/some-trial-component-name"
)
assert describe_after_metrics["Metrics"][0]["TimeStamp"] == datetime.datetime(
2024, 4, 21, 21, 33, 3
)
assert describe_after_metrics["Metrics"][0]["Max"] == 123.0
assert describe_after_metrics["Metrics"][0]["Min"] == 123.0
assert describe_after_metrics["Metrics"][0]["Last"] == 123.0
assert describe_after_metrics["Metrics"][0]["Count"] == 1
assert describe_after_metrics["Metrics"][0]["Avg"] == 123.0
assert describe_after_metrics["Metrics"][0]["StdDev"] == 0.0


@mock_aws
def test_batch_put_metrics_should_return_validation_error_if_trial_component_not_found():
trial_component_name = "some-trial-component-name-not-existing"
client = boto3.client(
"sagemaker-metrics",
region_name="eu-west-1"
)
client = boto3.client("sagemaker-metrics", region_name="eu-west-1")
resp = client.batch_put_metrics(
TrialComponentName=trial_component_name,
MetricData=[{
'MetricName': 'some-metric-name',
'Timestamp': datetime(2015, 1, 1),
'Step': 0,
'Value': 123.0,
}]
MetricData=[
{
"MetricName": "some-metric-name",
"Timestamp": datetime.datetime(2015, 1, 1),
"Step": 0,
"Value": 123.0,
}
],
)

assert resp.get("Errors") is not None
assert resp["Errors"][0]["Code"] == "VALIDATION_ERROR"

2 changes: 1 addition & 1 deletion tests/test_sagemakermetrics/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ def test_sagemakermetrics_batch_put_metrics():
resp = test_client.put("/BatchPutMetrics")

assert resp.status_code == 200
assert "VALIDATION_ERROR" in str(resp.data)
assert "VALIDATION_ERROR" in str(resp.data)

0 comments on commit e614c1c

Please sign in to comment.