Skip to content

Commit

Permalink
Apply code review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
droberts195 committed Mar 19, 2024
1 parent 1e92cfe commit 475e98e
Showing 1 changed file with 55 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
Expand All @@ -27,7 +28,6 @@
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilder;
Expand All @@ -39,6 +39,7 @@
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.Transports;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
Expand Down Expand Up @@ -122,47 +123,36 @@ protected void doExecute(

GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();

ListenableFuture<Map<String, TrainedModelSizeStats>> modelSizeStatsListener = new ListenableFuture<>();
modelSizeStatsListener.addListener(listener.delegateFailureAndWrap((l, modelSizeStatsByModelId) -> {
responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
l.onResponse(
responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata))
);
}));

ListenableFuture<GetDeploymentStatsAction.Response> deploymentStatsListener = new ListenableFuture<>();
deploymentStatsListener.addListener(listener.delegateFailureAndWrap((delegate, deploymentStats) -> executor.execute(() -> {
// deployment stats for each matching deployment
// not necessarily for all models
responseBuilder.setDeploymentStatsByDeploymentId(
deploymentStats.getStats()
.results()
.stream()
.collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity()))
);
SubscribableListener.<Tuple<Long, Map<String, Set<String>>>>newForked(l -> {
// When the request resource is a deployment find the
// model used in that deployment for the model stats
String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata);
logger.debug("Expanded models/deployment Ids request [{}]", idExpression);

int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
modelSizeStats(
responseBuilder.getExpandedModelIdsWithAliases(),
// the request id may contain deployment ids
// It is not an error if these don't match a model id but
// they need to be included in case the deployment id is also
// a model id. Hence, the `matchedDeploymentIds` parameter
trainedModelProvider.expandIds(
idExpression,
request.isAllowNoResources(),
request.getPageParams(),
Collections.emptySet(),
modelAliasMetadata,
parentTaskId,
modelSizeStatsListener,
numberOfAllocations
matchedDeploymentIds,
l
);
})));

ListenableFuture<List<InferenceStats>> inferenceStatsListener = new ListenableFuture<>();
// inference stats are per model and are only
// persisted for boosted tree models
inferenceStatsListener.addListener(listener.delegateFailureAndWrap((l, inferenceStats) -> executor.execute(() -> {
responseBuilder.setInferenceStatsByModelId(
inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
}).<NodesStatsResponse>andThen((l, tuple) -> {
responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1());
executeAsyncWithOrigin(
client,
ML_ORIGIN,
TransportNodesStatsAction.TYPE,
nodeStatsRequest(clusterService.state(), parentTaskId),
l
);
getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, deploymentStatsListener);
})));

ListenableFuture<NodesStatsResponse> nodesStatsListener = new ListenableFuture<>();
nodesStatsListener.addListener(listener.delegateFailureAndWrap((delegate, nodesStatsResponse) -> executor.execute(() -> {
}).<List<InferenceStats>>andThen(executor, null, (l, nodesStatsResponse) -> {
// find all pipelines whether using the model id,
// alias or deployment id.
Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases()
Expand All @@ -182,46 +172,43 @@ protected void doExecute(
trainedModelProvider.getInferenceStats(
responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]),
parentTaskId,
inferenceStatsListener
l
);
})));

ListenableFuture<Tuple<Long, Map<String, Set<String>>>> idsListener = new ListenableFuture<>();
idsListener.addListener(listener.delegateFailureAndWrap((delegate, tuple) -> {
responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1());
executeAsyncWithOrigin(
client,
ML_ORIGIN,
TransportNodesStatsAction.TYPE,
nodeStatsRequest(clusterService.state(), parentTaskId),
nodesStatsListener
}).<GetDeploymentStatsAction.Response>andThen(executor, null, (l, inferenceStats) -> {
// inference stats are per model and are only
// persisted for boosted tree models
responseBuilder.setInferenceStatsByModelId(
inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
);
getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, l);
}).<Map<String, TrainedModelSizeStats>>andThen(executor, null, (l, deploymentStats) -> {
// deployment stats for each matching deployment
// not necessarily for all models
responseBuilder.setDeploymentStatsByDeploymentId(
deploymentStats.getStats()
.results()
.stream()
.collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity()))
);
}));

executor.execute(() -> {
// When the request resource is a deployment find the
// model used in that deployment for the model stats
String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata);
logger.debug("Expanded models/deployment Ids request [{}]", idExpression);

// the request id may contain deployment ids
// It is not an error if these don't match a model id but
// they need to be included in case the deployment id is also
// a model id. Hence, the `matchedDeploymentIds` parameter
trainedModelProvider.expandIds(
idExpression,
int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
modelSizeStats(
responseBuilder.getExpandedModelIdsWithAliases(),
request.isAllowNoResources(),
request.getPageParams(),
Collections.emptySet(),
modelAliasMetadata,
parentTaskId,
matchedDeploymentIds,
idsListener
l,
numberOfAllocations
);
});
}).<GetTrainedModelsStatsAction.Response>andThen((l, modelSizeStatsByModelId) -> {
responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
l.onResponse(
responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata))
);
}).addListener(listener, executor, null);
}

static String addModelsUsedInMatchingDeployments(String idExpression, TrainedModelAssignmentMetadata assignmentMetadata) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
if (Strings.isAllOrWildcard(idExpression)) {
return idExpression;
} else {
Expand Down

0 comments on commit 475e98e

Please sign in to comment.