Merge pull request #1 from DaveCTurner/2024/03/19/99439-example
Suggestions
droberts195 committed Mar 20, 2024
2 parents ae2cb66 + ae4d7c3 commit 98d45f8
Showing 1 changed file with 116 additions and 92 deletions.
@@ -9,6 +9,7 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
 import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
 import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters;
@@ -17,8 +18,8 @@
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.TransportSearchAction;
 import org.elasticsearch.action.support.ActionFilters;
-import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.action.support.SubscribableListener;
+import org.elasticsearch.action.support.TransportAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -27,7 +28,6 @@
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.metrics.CounterMetric;
 import org.elasticsearch.common.util.Maps;
-import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.query.QueryBuilder;
@@ -76,7 +76,7 @@
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
 import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByResource;
 
-public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<
+public class TransportGetTrainedModelsStatsAction extends TransportAction<
     GetTrainedModelsStatsAction.Request,
     GetTrainedModelsStatsAction.Response> {
 
@@ -96,13 +96,7 @@ public TransportGetTrainedModelsStatsAction(
         TrainedModelProvider trainedModelProvider,
         Client client
     ) {
-        super(
-            GetTrainedModelsStatsAction.NAME,
-            transportService,
-            actionFilters,
-            GetTrainedModelsStatsAction.Request::new,
-            EsExecutors.DIRECT_EXECUTOR_SERVICE
-        );
+        super(GetTrainedModelsStatsAction.NAME, actionFilters, transportService.getTaskManager());
         this.client = client;
         this.clusterService = clusterService;
         this.trainedModelProvider = trainedModelProvider;
@@ -114,6 +108,15 @@ protected void doExecute(
         Task task,
         GetTrainedModelsStatsAction.Request request,
         ActionListener<GetTrainedModelsStatsAction.Response> listener
     ) {
+        // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
+        executor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l)));
+    }
+
+    protected void doExecuteForked(
+        Task task,
+        GetTrainedModelsStatsAction.Request request,
+        ActionListener<GetTrainedModelsStatsAction.Response> listener
+    ) {
         final TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
         final ModelAliasMetadata modelAliasMetadata = ModelAliasMetadata.fromState(clusterService.state());
@@ -122,90 +125,111 @@ protected void doExecute(
 
         GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
 
-        SubscribableListener.<String>newForked(l -> {
-            // When the request resource is a deployment find the
-            // model used in that deployment for the model stats
-            String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata);
-            l.onResponse(idExpression);
-        }).<Tuple<Long, Map<String, Set<String>>>>andThen(executor, null, (l, idExpression) -> {
-            logger.debug("Expanded models/deployment Ids request [{}]", idExpression);
-
-            // the request id may contain deployment ids
-            // It is not an error if these don't match a model id but
-            // they need to be included in case the deployment id is also
-            // a model id. Hence, the `matchedDeploymentIds` parameter
-            trainedModelProvider.expandIds(
-                idExpression,
-                request.isAllowNoResources(),
-                request.getPageParams(),
-                Collections.emptySet(),
-                modelAliasMetadata,
-                parentTaskId,
-                matchedDeploymentIds,
-                l
-            );
-        }).<NodesStatsResponse>andThen((l, tuple) -> {
-            responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1());
-            executeAsyncWithOrigin(
-                client,
-                ML_ORIGIN,
-                TransportNodesStatsAction.TYPE,
-                nodeStatsRequest(clusterService.state(), parentTaskId),
-                l
-            );
-        }).<List<InferenceStats>>andThen(executor, null, (l, nodesStatsResponse) -> {
-            // find all pipelines whether using the model id,
-            // alias or deployment id.
-            Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases()
-                .entrySet()
-                .stream()
-                .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey())))
-                .collect(Collectors.toSet());
-            allPossiblePipelineReferences.addAll(matchedDeploymentIds);
-
-            Map<String, Set<String>> pipelineIdsByResource = pipelineIdsByResource(clusterService.state(), allPossiblePipelineReferences);
-            Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(
-                nodesStatsResponse,
-                modelAliasMetadata,
-                pipelineIdsByResource
-            );
-            responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
-            trainedModelProvider.getInferenceStats(
-                responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]),
-                parentTaskId,
-                l
-            );
-        }).<GetDeploymentStatsAction.Response>andThen(executor, null, (l, inferenceStats) -> {
-            // inference stats are per model and are only
-            // persisted for boosted tree models
-            responseBuilder.setInferenceStatsByModelId(
-                inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
-            );
-            getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, l);
-        }).<Map<String, TrainedModelSizeStats>>andThen(executor, null, (l, deploymentStats) -> {
-            // deployment stats for each matching deployment
-            // not necessarily for all models
-            responseBuilder.setDeploymentStatsByDeploymentId(
-                deploymentStats.getStats()
-                    .results()
+        SubscribableListener
+
+            .<Tuple<Long, Map<String, Set<String>>>>newForked(l -> {
+                // When the request resource is a deployment find the model used in that deployment for the model stats
+                final var idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata);
+
+                logger.debug("Expanded models/deployment Ids request [{}]", idExpression);
+
+                // the request id may contain deployment ids
+                // It is not an error if these don't match a model id but
+                // they need to be included in case the deployment id is also
+                // a model id. Hence, the `matchedDeploymentIds` parameter
+                trainedModelProvider.expandIds(
+                    idExpression,
+                    request.isAllowNoResources(),
+                    request.getPageParams(),
+                    Collections.emptySet(),
+                    modelAliasMetadata,
+                    parentTaskId,
+                    matchedDeploymentIds,
+                    l
+                );
+            })
+            .andThenAccept(tuple -> responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()))
+
+            .<NodesStatsResponse>andThen(
+                (l, ignored) -> executeAsyncWithOrigin(
+                    client,
+                    ML_ORIGIN,
+                    TransportNodesStatsAction.TYPE,
+                    nodeStatsRequest(clusterService.state(), parentTaskId),
+                    l
+                )
+            )
+            .<List<InferenceStats>>andThen(executor, null, (l, nodesStatsResponse) -> {
+                // find all pipelines whether using the model id, alias or deployment id.
+                Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases()
+                    .entrySet()
                     .stream()
-                    .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity()))
-            );
+                    .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey())))
+                    .collect(Collectors.toSet());
+                allPossiblePipelineReferences.addAll(matchedDeploymentIds);
 
-            int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
-            modelSizeStats(
-                responseBuilder.getExpandedModelIdsWithAliases(),
-                request.isAllowNoResources(),
-                parentTaskId,
-                l,
-                numberOfAllocations
-            );
-        }).<GetTrainedModelsStatsAction.Response>andThen((l, modelSizeStatsByModelId) -> {
-            responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
-            l.onResponse(
-                responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata))
-            );
-        }).addListener(listener, executor, null);
+                Map<String, Set<String>> pipelineIdsByResource = pipelineIdsByResource(
+                    clusterService.state(),
+                    allPossiblePipelineReferences
+                );
+                Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(
+                    nodesStatsResponse,
+                    modelAliasMetadata,
+                    pipelineIdsByResource
+                );
+                responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
+                trainedModelProvider.getInferenceStats(
+                    responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]),
+                    parentTaskId,
+                    l
+                );
+            })
+            .andThenAccept(
+                // inference stats are per model and are only persisted for boosted tree models
+                inferenceStats -> responseBuilder.setInferenceStatsByModelId(
+                    inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
+                )
+            )
+
+            .<GetDeploymentStatsAction.Response>andThen(
+                executor,
+                null,
+                (l, ignored) -> getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, l)
+            )
+            .andThenApply(deploymentStats -> {
+                // deployment stats for each matching deployment not necessarily for all models
+                responseBuilder.setDeploymentStatsByDeploymentId(
+                    deploymentStats.getStats()
+                        .results()
+                        .stream()
+                        .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity()))
+                );
+                return deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
+            })
+
+            .<Map<String, TrainedModelSizeStats>>andThen(
+                executor,
+                null,
+                (l, numberOfAllocations) -> modelSizeStats(
+                    responseBuilder.getExpandedModelIdsWithAliases(),
+                    request.isAllowNoResources(),
+                    parentTaskId,
+                    l.map(modelSizeStatsByModelId -> {
+                        responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
+                        return null;
+                    }),
+                    numberOfAllocations
+                )
+            )
+            .andThenAccept(responseBuilder::setModelSizeStatsByModelId)
+
+            .andThenApply(
+                ignored -> responseBuilder.build(
+                    modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata)
+                )
+            )
+
+            .addListener(listener, executor, null);
     }
 
     static String addModelsUsedInMatchingDeployments(String idExpression, TrainedModelAssignmentMetadata assignmentMetadata) {
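For context, the change above converts the action from HandledTransportAction to a plain TransportAction and forks the whole request onto a pool executor exactly once, up front (the comment in the diff points at elastic/elasticsearch#97916). A minimal sketch of that shape, assuming an injected Executor; the class name, action name and the MyRequest/MyResponse types are hypothetical stand-ins, not the merged code:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.TransportAction;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;

import java.util.concurrent.Executor;

// Hypothetical action: MyRequest/MyResponse stand in for real ActionRequest/ActionResponse subclasses.
public class TransportMyAction extends TransportAction<MyRequest, MyResponse> {

    private final Executor executor;

    public TransportMyAction(TransportService transportService, ActionFilters actionFilters, Executor executor) {
        // As in the diff: a plain TransportAction is constructed with the task manager,
        // not with a request reader and an executor.
        super("cluster:monitor/xpack/my_action", actionFilters, transportService.getTaskManager());
        this.executor = executor;
    }

    @Override
    protected void doExecute(Task task, MyRequest request, ActionListener<MyResponse> listener) {
        // Fork once, up front; ActionRunnable.wrap routes exceptions from the forked body to the listener.
        executor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l)));
    }

    private void doExecuteForked(Task task, MyRequest request, ActionListener<MyResponse> listener) {
        // The real work runs here, off the calling thread, and must eventually complete `listener`.
    }
}

The effect of the wrapper is simply that everything after doExecute runs on the injected executor rather than on whichever thread invoked the action.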

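The rewritten method drives its control flow through SubscribableListener's chaining helpers (newForked, andThen, andThenAccept, andThenApply and addListener), all of which appear in the diff above. Below is a minimal, self-contained sketch of that pattern; the step methods (fetchCount, fetchNames, buildReport) are hypothetical stand-ins for the real ML lookups, and the single-argument addListener overload is used for brevity where the diff passes (listener, executor, null):

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;

import java.util.List;

// Illustrative only: none of the step methods below exist in Elasticsearch.
public class ChainSketch {

    void run(ActionListener<String> finalListener) {
        SubscribableListener

            // Fork the first asynchronous step; the chain advances when `l` is completed.
            .<Long>newForked(l -> fetchCount(l))

            // Consume the previous result without producing a new one.
            .andThenAccept(count -> System.out.println("count = " + count))

            // Another asynchronous step; `ignored` is the empty (Void) result of the accept step.
            .<List<String>>andThen((l, ignored) -> fetchNames(l))

            // Synchronous transformation of the previous result.
            .andThenApply(names -> buildReport(names))

            // Deliver the final result, or the first failure, to the caller.
            .addListener(finalListener);
    }

    // Hypothetical asynchronous sources; real code would call other APIs and complete the listener later.
    void fetchCount(ActionListener<Long> listener) {
        listener.onResponse(42L);
    }

    void fetchNames(ActionListener<List<String>> listener) {
        listener.onResponse(List.of("model-a", "model-b"));
    }

    String buildReport(List<String> names) {
        return "report for " + String.join(", ", names);
    }
}

Each andThen* step runs only if the previous step succeeded; a failure anywhere in the chain skips the remaining steps and is delivered to the listener passed to addListener.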