add tests
YHallouard committed Apr 22, 2024
1 parent 45f13c2 commit e5cf1af
Showing 8 changed files with 136 additions and 57 deletions.
42 changes: 41 additions & 1 deletion moto/sagemaker/models.py
@@ -73,6 +73,9 @@
},
}

METRIC_INFO_TYPE = Dict[str, Union[str, int, float, datetime]]
METRIC_STEP_TYPE = Dict[int, METRIC_INFO_TYPE]


class BaseObject(BaseModel):
def camelCase(self, key: str) -> str:
@@ -3755,16 +3758,53 @@ def __init__(
self.input_artifacts = input_artifacts if input_artifacts is not None else {}
self.output_artifacts = output_artifacts if output_artifacts is not None else {}
self.metadata_properties = metadata_properties
self.metrics: List[Dict[str, Union[float, datetime, str]]] = []
self.metrics: Dict[str, Dict[str, Union[str, int, METRIC_STEP_TYPE]]] = {}
self.sources: List[Dict[str, str]] = []

@property
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["Metrics"] = self.gen_metrics_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}

def gen_metrics_response_object(
self,
) -> List[Dict[str, Union[str, int, float, datetime]]]:
metrics_names = self.metrics.keys()
metrics_response_objects = []
for metrics_name in metrics_names:
metrics_steps: METRIC_STEP_TYPE = cast(METRIC_STEP_TYPE, self.metrics[metrics_name]["Values"])
max_step = max(list(metrics_steps.keys()))
metrics_steps_values: List[float] = list(
map(
lambda metric: cast(float, metric["Value"]),
list(metrics_steps.values()),
)
)
count = len(metrics_steps_values)
mean = sum(metrics_steps_values) / count
std = (
sum(map(lambda value: (value - mean) ** 2, metrics_steps_values)) / count
) ** 0.5
timestamp_int: int = cast(int, self.metrics[metrics_name]["Timestamp"])
metrics_response_object = {
"MetricName": metrics_name,
"SourceArn": self.trial_component_arn,
"TimeStamp": datetime.fromtimestamp(timestamp_int).strftime(
"%Y-%m-%d %H:%M:%S"
),
"Max": max(metrics_steps_values),
"Min": min(metrics_steps_values),
"Last": metrics_steps[max_step]["Value"],
"Count": count,
"Avg": mean,
"StdDev": std,
}
metrics_response_objects.append(metrics_response_object)
return metrics_response_objects

@property
def response_create(self) -> Dict[str, str]:
return {"TrialComponentArn": self.trial_component_arn}
2 changes: 1 addition & 1 deletion moto/sagemakermetrics/__init__.py
@@ -1 +1 @@
from .models import sagemakermetrics_backends #noqa: F401
from .models import sagemakermetrics_backends # noqa: F401
2 changes: 0 additions & 2 deletions moto/sagemakermetrics/exceptions.py
@@ -1,3 +1 @@
"""Exceptions raised by the sagemakermetrics service."""
from moto.core.exceptions import JsonRESTError

50 changes: 37 additions & 13 deletions moto/sagemakermetrics/models.py
@@ -1,30 +1,54 @@
"""SageMakerMetricsBackend class with methods for supported APIs."""

from datetime import datetime
from typing import List, Dict, Union
from typing import Dict, List, Union, cast

from moto.core.base_backend import BaseBackend, BackendDict
from moto.core.base_backend import BackendDict, BaseBackend
from moto.sagemaker import sagemaker_backends
from moto.sagemaker.models import METRIC_STEP_TYPE

RESPONSE_TYPE = Dict[str, List[Dict[str, Union[str, int]]]]


class SageMakerMetricsBackend(BaseBackend):
"""Implementation of SageMakerMetrics APIs."""

def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str) -> None:
super().__init__(region_name, account_id)
self.sagemaker_backend = sagemaker_backends[account_id][region_name]

def batch_put_metrics(
self,
trial_component_name: str,
metric_data: List[Dict[str, Union[str, int, float, datetime]]],
):
self,
trial_component_name: str,
metric_data: List[Dict[str, Union[str, int, float, datetime]]],
) -> RESPONSE_TYPE:
return_response: RESPONSE_TYPE = {"Errors": []}

if trial_component_name not in self.sagemaker_backend.trial_components:
return {
"Errors": [{'Code': 'VALIDATION_ERROR', 'MetricIndex': 0}]
}
return_response["Errors"].append(
{"Code": "VALIDATION_ERROR", "MetricIndex": 0}
)
return return_response

trial_component = self.sagemaker_backend.trial_components[trial_component_name]
trial_component.metrics.extend(metric_data)
return {}

for metric in metric_data:
metric_step: int = cast(int, metric["Step"])
metric_name: str = cast(str, metric["MetricName"])
if metric_name not in trial_component.metrics:
metric_timestamp: int = cast(int, metric["Timestamp"])
values_dict: Dict[int, Dict[str, Union[str, int, float, datetime]]] = {}
new_metric: Dict[str, Union[str, int, METRIC_STEP_TYPE]] = {
"MetricName": metric_name,
"Timestamp": metric_timestamp,
"Values": values_dict,
}
trial_component.metrics[metric_name] = new_metric
new_step: METRIC_STEP_TYPE = {metric_step: metric}
trial_component_metric_values: METRIC_STEP_TYPE = cast(
METRIC_STEP_TYPE, trial_component.metrics[metric_name]["Values"]
)
trial_component_metric_values.update(new_step)  # type: ignore
return return_response


sagemakermetrics_backends = BackendDict(SageMakerMetricsBackend, "sagemaker-metrics")
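
Compared with the previous list-based self.metrics, metrics are now stored per metric name, with the per-step entries keyed by Step so that re-sending a step updates it in place. Roughly, after putting steps 0 and 1 of a metric called "loss", trial_component.metrics would hold a structure like this sketch (names, values, and timestamps invented for illustration):

trial_component_metrics = {
    "loss": {
        "MetricName": "loss",
        "Timestamp": 1713814383,  # epoch seconds taken from the first entry seen
        "Values": {  # keyed by Step; a repeated Step overwrites the earlier entry
            0: {"MetricName": "loss", "Timestamp": 1713814383, "Step": 0, "Value": 0.9},
            1: {"MetricName": "loss", "Timestamp": 1713814443, "Step": 1, "Value": 0.7},
        },
    },
}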
12 changes: 7 additions & 5 deletions moto/sagemakermetrics/responses.py
@@ -1,22 +1,24 @@
"""Handles incoming sagemakermetrics requests, invokes methods, returns responses."""

import json

from moto.core.responses import BaseResponse
from .models import sagemakermetrics_backends

from .models import SageMakerMetricsBackend, sagemakermetrics_backends


class SageMakerMetricsResponse(BaseResponse):
"""Handler for SageMakerMetrics requests and responses."""

def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="sagemaker-metrics")

@property
def sagemakermetrics_backend(self):
def sagemakermetrics_backend(self) -> SageMakerMetricsBackend:
"""Return backend instance specific for this region."""
return sagemakermetrics_backends[self.current_account][self.region]
def batch_put_metrics(self):

def batch_put_metrics(self) -> str:
trial_component_name = self._get_param("TrialComponentName")
metric_data = self._get_param("MetricData")
errors = self.sagemakermetrics_backend.batch_put_metrics(
1 change: 1 addition & 0 deletions moto/sagemakermetrics/urls.py
@@ -1,4 +1,5 @@
"""sagemakermetrics base URL and path."""

from .responses import SageMakerMetricsResponse

url_bases = [
82 changes: 48 additions & 34 deletions tests/test_sagemakermetrics/test_sagemakermetrics.py
@@ -1,58 +1,72 @@
"""Unit tests for sagemakermetrics-supported APIs."""
from datetime import datetime

import datetime

import boto3

from moto import mock_aws

# @mock_aws

@mock_aws
def test_batch_put_metrics():
trial_component_name = "some-trial-component-name"
client_sagemaker = boto3.client(
"sagemaker",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name="eu-west-1"
)
# client_sagemaker.create_trial_component(TrialComponentName=trial_component_name)

client = boto3.client(
"sagemaker-metrics",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name="eu-west-1"
client_sagemaker = boto3.client("sagemaker", region_name="eu-west-1")
client_sagemaker.create_trial_component(TrialComponentName=trial_component_name)
describe_before_metrics = client_sagemaker.describe_trial_component(
TrialComponentName=trial_component_name
)

client = boto3.client("sagemaker-metrics", region_name="eu-west-1")
given_datetime = datetime.datetime(2024, 4, 21, 19, 33, 3)
resp = client.batch_put_metrics(
TrialComponentName=trial_component_name,
MetricData=[{
'MetricName': 'some-metric-name',
'Timestamp': datetime(2015, 1, 1),
'Step': 123,
'Value': 123.0
},]
MetricData=[
{
"MetricName": "some-metric-name",
"Timestamp": given_datetime,
"Step": 0,
"Value": 123.0,
},
],
)
describe_after_metrics = client_sagemaker.describe_trial_component(
TrialComponentName=trial_component_name
)

assert describe_before_metrics["Metrics"] == []
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert resp["Errors"] == []
assert describe_after_metrics["Metrics"][0]["MetricName"] == "some-metric-name"
assert (
describe_after_metrics["Metrics"][0]["SourceArn"]
== "arn:aws:sagemaker:eu-west-1:123456789012:experiment-trial-component/some-trial-component-name"
)
assert describe_after_metrics["Metrics"][0]["TimeStamp"] == datetime.datetime(
2024, 4, 21, 21, 33, 3
)
assert describe_after_metrics["Metrics"][0]["Max"] == 123.0
assert describe_after_metrics["Metrics"][0]["Min"] == 123.0
assert describe_after_metrics["Metrics"][0]["Last"] == 123.0
assert describe_after_metrics["Metrics"][0]["Count"] == 1
assert describe_after_metrics["Metrics"][0]["Avg"] == 123.0
assert describe_after_metrics["Metrics"][0]["StdDev"] == 0.0


@mock_aws
def test_batch_put_metrics_should_return_validation_error_if_trial_component_not_found():
trial_component_name = "some-trial-component-name-not-existing"
client = boto3.client(
"sagemaker-metrics",
region_name="eu-west-1"
)
client = boto3.client("sagemaker-metrics", region_name="eu-west-1")
resp = client.batch_put_metrics(
TrialComponentName=trial_component_name,
MetricData=[{
'MetricName': 'some-metric-name',
'Timestamp': datetime(2015, 1, 1),
'Step': 0,
'Value': 123.0,
}]
MetricData=[
{
"MetricName": "some-metric-name",
"Timestamp": datetime.datetime(2015, 1, 1),
"Step": 0,
"Value": 123.0,
}
],
)

assert resp.get("Errors") is not None
assert resp["Errors"][0]["Code"] == "VALIDATION_ERROR"

2 changes: 1 addition & 1 deletion tests/test_sagemakermetrics/test_server.py
@@ -10,4 +10,4 @@ def test_sagemakermetrics_batch_put_metrics():
resp = test_client.put("/BatchPutMetrics")

assert resp.status_code == 200
assert "VALIDATION_ERROR" in str(resp.data)
assert "VALIDATION_ERROR" in str(resp.data)
