Skip to content

Commit

Permalink
improvements from DK review
Browse files (browse the repository at this point in the history)
  • Loading branch information
maxhniebergall committed May 9, 2024
1 parent 2e9d555 commit 5fda71b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 59 deletions.
Expand Up @@ -20,7 +20,6 @@ public static ElasticsearchCluster mixedVersionCluster() {
.withNode(node -> node.version(Version.CURRENT))
.setting("xpack.security.enabled", "false")
.setting("xpack.license.self_generated.type", "trial")
.setting("cluster.routing.rebalance.enable", "none") // disable relocation until we have retry in ESQL
.build();
}
}
Expand Up @@ -57,59 +57,33 @@ public void testCohereEmbeddings() throws IOException {
var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_EMBEDDINGS_ADDED));
assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported);

final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
final String oldClusterIdFloat = "old-cluster-embeddings-float";
final String inferenceIdInt8 = "mixed-cluster-cohere-embeddings-int8";
final String inferenceIdFloat = "mixed-cluster-cohere-embeddings-float";

// queue a response as PUT will call the service
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
// float model
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
put(oldClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints");
var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdInt8).get("endpoints");
assertEquals("cohere", configs.get(0).get("service"));
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
var embeddingType = serviceSettings.get("embedding_type");
// An upgraded node will report the embedding type as byte, an old node int8
assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));

configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterIdFloat).get("endpoints");
configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdFloat).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "float"));

assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
assertEmbeddingInference(inferenceIdInt8, CohereEmbeddingType.BYTE);
assertEmbeddingInference(inferenceIdFloat, CohereEmbeddingType.FLOAT);

{
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";

cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdInt8).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte

assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8);
delete(upgradedClusterIdInt8);
}
{
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdFloat).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "float"));

assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT);
delete(upgradedClusterIdFloat);
}

delete(oldClusterIdFloat);
delete(oldClusterIdInt8);
delete(inferenceIdFloat);
delete(inferenceIdInt8);

}

Expand All @@ -132,20 +106,20 @@ public void testRerank() throws IOException {
var rerankSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_RERANK_ADDED));
assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported);

final String oldClusterId = "old-cluster-rerank";
final String inferenceId = "mixed-cluster-rerank";

put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
assertRerank(oldClusterId);
put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
assertRerank(inferenceId);

var configs = (List<Map<String, Object>>) get(TaskType.RERANK, oldClusterId).get("endpoints");
var configs = (List<Map<String, Object>>) get(TaskType.RERANK, inferenceId).get("endpoints");
assertThat(configs, hasSize(1));
assertEquals("cohere", configs.get(0).get("service"));
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0"));
var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
assertThat(taskSettings, hasEntry("top_n", 3));

assertRerank(oldClusterId);
assertRerank(inferenceId);

}

Expand Down
Expand Up @@ -50,14 +50,14 @@ public void testHFEmbeddings() throws IOException {
var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(HF_EMBEDDINGS_ADDED));
assumeTrue("Hugging Face embedding service added in " + HF_EMBEDDINGS_ADDED, embeddingsSupported);

final String oldClusterId = "old-cluster-embeddings";
final String inferenceId = "mixed-cluster-embeddings";

embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
put(oldClusterId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING);
var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints");
put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING);
var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
assertThat(configs, hasSize(1));
assertEquals("hugging_face", configs.get(0).get("service"));
assertEmbeddingInference(oldClusterId);
assertEmbeddingInference(inferenceId);
}

void assertEmbeddingInference(String inferenceId) throws IOException {
Expand All @@ -71,15 +71,15 @@ public void testElser() throws IOException {
var supported = bwcVersion.onOrAfter(Version.fromString(HF_ELSER_ADDED));
assumeTrue("HF elser service added in " + HF_ELSER_ADDED, supported);

final String oldClusterId = "old-cluster-elser";
final String inferenceId = "mixed-cluster-elser";
final String upgradedClusterId = "upgraded-cluster-elser";

put(oldClusterId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING);
put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING);

var configs = (List<Map<String, Object>>) get(TaskType.SPARSE_EMBEDDING, oldClusterId).get("endpoints");
var configs = (List<Map<String, Object>>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints");
assertThat(configs, hasSize(1));
assertEquals("hugging_face", configs.get(0).get("service"));
assertElser(oldClusterId);
assertElser(inferenceId);
}

private void assertElser(String inferenceId) throws IOException {
Expand Down
Expand Up @@ -54,22 +54,22 @@ public void testOpenAiEmbeddings() throws IOException {
var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
assumeTrue("OpenAI embedding service added in " + OPEN_AI_EMBEDDINGS_ADDED, openAiEmbeddingsSupported);

final String oldClusterId = "old-cluster-embeddings";
final String inferenceId = "mixed-cluster-embeddings";

String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig();
// queue a response as PUT will call the service
openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
put(oldClusterId, inferenceConfig, TaskType.TEXT_EMBEDDING);
put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);

var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints");
var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
assertThat(configs, hasSize(1));
assertEquals("openai", configs.get(0).get("service"));
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
var modelIdFound = serviceSettings.containsKey("model_id") || taskSettings.containsKey("model_id");
assertTrue("model_id not found in config: " + configs.toString(), modelIdFound);

assertEmbeddingInference(oldClusterId);
assertEmbeddingInference(inferenceId);
}

void assertEmbeddingInference(String inferenceId) throws IOException {
Expand All @@ -83,12 +83,12 @@ public void testOpenAiCompletions() throws IOException {
var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
assumeTrue("OpenAI completions service added in " + OPEN_AI_COMPLETIONS_ADDED, openAiEmbeddingsSupported);

final String oldClusterId = "old-cluster-completions";
final String inferenceId = "mixed-cluster-completions";
final String upgradedClusterId = "upgraded-cluster-completions";

put(oldClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION);
put(inferenceId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION);

var configsMap = get(TaskType.COMPLETION, oldClusterId);
var configsMap = get(TaskType.COMPLETION, inferenceId);
logger.warn("Configs: {}", configsMap);
var configs = (List<Map<String, Object>>) configsMap.get("endpoints");
assertThat(configs, hasSize(1));
Expand All @@ -98,7 +98,7 @@ public void testOpenAiCompletions() throws IOException {
var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
assertThat(taskSettings.keySet(), empty());

assertCompletionInference(oldClusterId);
assertCompletionInference(inferenceId);
}

void assertCompletionInference(String inferenceId) throws IOException {
Expand Down
Expand Up @@ -120,7 +120,6 @@ public void testOpenAiCompletions() throws IOException {
final String upgradedClusterId = "upgraded-cluster-completions";

if (isOldCluster()) {
// TODO why is put only in old cluster?
put(oldClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION);

var configs = (List<Map<String, Object>>) get(TaskType.COMPLETION, oldClusterId).get("models");
Expand Down

0 comments on commit 5fda71b

Please sign in to comment.