diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index fc4323e418b72..e3f4f99ee0425 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -184,6 +184,7 @@ static TransportVersion def(int id) { public static final TransportVersion ADD_RESOURCE_ALREADY_UPLOADED_EXCEPTION = def(8_643_00_0); public static final TransportVersion ESQL_MV_ORDERING_SORTED_ASCENDING = def(8_644_00_0); public static final TransportVersion ESQL_PAGE_MAPPING_TO_ITERATOR = def(8_645_00_0); + public static final TransportVersion ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL = def(8_646_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java index 134933deab917..02baa5f715223 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfig.java @@ -148,7 +148,11 @@ public QuestionAnsweringConfig(StreamInput in) throws IOException { vocabularyConfig = new VocabularyConfig(in); tokenization = in.readNamedWriteable(Tokenization.class); resultsField = in.readOptionalString(); - question = in.readOptionalString(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL)) { + question = in.readString(); + } else { + question = in.readOptionalString(); + } } @Override @@ -158,7 +162,11 @@ public void writeTo(StreamOutput out) throws IOException { vocabularyConfig.writeTo(out); out.writeNamedWriteable(tokenization); out.writeOptionalString(resultsField); - out.writeOptionalString(question); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL)) { + out.writeString(question); + } else { + out.writeOptionalString(question); + } } @Override @@ -171,9 +179,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (resultsField != null) { builder.field(RESULTS_FIELD.getPreferredName(), resultsField); } - if (question != null) { - builder.field(QUESTION.getPreferredName(), question); - } + builder.field(QUESTION.getPreferredName(), question); builder.endObject(); return builder; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java index e787b770b5da5..800eb1a3797b9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/QuestionAnsweringConfigUpdateTests.java @@ -35,6 +35,7 @@ public static QuestionAnsweringConfigUpdate randomUpdate() { ); } + // TODO add test for question public static QuestionAnsweringConfigUpdate mutateForVersion(QuestionAnsweringConfigUpdate instance, TransportVersion version) { if (version.before(TransportVersions.V_8_1_0)) { return new QuestionAnsweringConfigUpdate( @@ -45,6 +46,15 @@ public static QuestionAnsweringConfigUpdate mutateForVersion(QuestionAnsweringCo null ); } + if (version.before(TransportVersions.ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL)) { + return new QuestionAnsweringConfigUpdate( + instance.getQuestion() == null ? instance.getQuestion() : null, + instance.getNumTopClasses(), + instance.getMaxAnswerLength(), + instance.getResultsField(), + instance.getTokenizationUpdate() + ); + } return instance; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java index 45571ea2a8238..464c8eac8c9dd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/BertTokenizer.java @@ -236,6 +236,10 @@ public NlpTask.RequestBuilder requestBuilder() { ).buildRequest(requestId, truncate); } + /** + * @param seq cannot be null + * @return InnerTokenization + */ @Override public InnerTokenization innerTokenize(String seq) { List tokenPositionMap = new ArrayList<>(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java index d604b52a55cc4..e884e84faa85d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/RobertaTokenizer.java @@ -178,6 +178,10 @@ TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepToke return new RobertaTokenizationResult.RobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId); } + /** + * @param seq cannot be null + * @return InnerTokenization + */ @Override public InnerTokenization innerTokenize(String seq) { List tokenPositionMap = new ArrayList<>(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java index 3c7d54cd547bf..7a856d8e4735a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java @@ -173,6 +173,10 @@ TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepToke return new XLMRobertaTokenizationResult.XLMRobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId); } + /** + * @param seq cannot be null + * @return InnerTokenization + */ @Override public InnerTokenization innerTokenize(String seq) { List tokenPositionMap = new ArrayList<>();