Skip to content

Commit

Permalink
Require question to be non-null in QuestionAnsweringConfig.
Browse files Browse the repository at this point in the history
The question can still be null in previous versions, but will cause an NPE in the tokenizer if used.
  • Loading branch information
maxhniebergall committed Apr 26, 2024
1 parent 4664ced commit ef304bf
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 5 deletions.
Expand Up @@ -184,6 +184,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ADD_RESOURCE_ALREADY_UPLOADED_EXCEPTION = def(8_643_00_0);
public static final TransportVersion ESQL_MV_ORDERING_SORTED_ASCENDING = def(8_644_00_0);
public static final TransportVersion ESQL_PAGE_MAPPING_TO_ITERATOR = def(8_645_00_0);
public static final TransportVersion ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL = def(8_646_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Expand Up @@ -148,7 +148,11 @@ public QuestionAnsweringConfig(StreamInput in) throws IOException {
vocabularyConfig = new VocabularyConfig(in);
tokenization = in.readNamedWriteable(Tokenization.class);
resultsField = in.readOptionalString();
question = in.readOptionalString();
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL)) {
question = in.readString();
} else {
question = in.readOptionalString();
}
}

@Override
Expand All @@ -158,7 +162,11 @@ public void writeTo(StreamOutput out) throws IOException {
vocabularyConfig.writeTo(out);
out.writeNamedWriteable(tokenization);
out.writeOptionalString(resultsField);
out.writeOptionalString(question);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL)) {
out.writeString(question);
} else {
out.writeOptionalString(question);
}
}

@Override
Expand All @@ -171,9 +179,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
if (question != null) {
builder.field(QUESTION.getPreferredName(), question);
}
builder.field(QUESTION.getPreferredName(), question);
builder.endObject();
return builder;
}
Expand Down
Expand Up @@ -35,6 +35,7 @@ public static QuestionAnsweringConfigUpdate randomUpdate() {
);
}

// TODO add test for question
public static QuestionAnsweringConfigUpdate mutateForVersion(QuestionAnsweringConfigUpdate instance, TransportVersion version) {
if (version.before(TransportVersions.V_8_1_0)) {
return new QuestionAnsweringConfigUpdate(
Expand All @@ -45,6 +46,15 @@ public static QuestionAnsweringConfigUpdate mutateForVersion(QuestionAnsweringCo
null
);
}
if (version.before(TransportVersions.ML_QUESTION_ANSWERING_CONFIG_REQUIRE_QUESTION_NON_NULL)) {
return new QuestionAnsweringConfigUpdate(
instance.getQuestion() == null ? instance.getQuestion() : null,
instance.getNumTopClasses(),
instance.getMaxAnswerLength(),
instance.getResultsField(),
instance.getTokenizationUpdate()
);
}
return instance;
}

Expand Down
Expand Up @@ -236,6 +236,10 @@ public NlpTask.RequestBuilder requestBuilder() {
).buildRequest(requestId, truncate);
}

/**
* @param seq the sequence to tokenize; must not be null
* @return the InnerTokenization for the given sequence
*/
@Override
public InnerTokenization innerTokenize(String seq) {
List<Integer> tokenPositionMap = new ArrayList<>();
Expand Down
Expand Up @@ -178,6 +178,10 @@ TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepToke
return new RobertaTokenizationResult.RobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
}

/**
* @param seq the sequence to tokenize; must not be null
* @return the InnerTokenization for the given sequence
*/
@Override
public InnerTokenization innerTokenize(String seq) {
List<Integer> tokenPositionMap = new ArrayList<>();
Expand Down
Expand Up @@ -173,6 +173,10 @@ TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepToke
return new XLMRobertaTokenizationResult.XLMRobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
}

/**
* @param seq the sequence to tokenize; must not be null
* @return the InnerTokenization for the given sequence
*/
@Override
public InnerTokenization innerTokenize(String seq) {
List<Integer> tokenPositionMap = new ArrayList<>();
Expand Down

0 comments on commit ef304bf

Please sign in to comment.