diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java index 62eec13af008a..9cc10a555f288 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java @@ -12,8 +12,6 @@ import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.predicate.Negatable; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; @@ -22,7 +20,7 @@ import java.time.ZoneId; import java.util.Map; -public class Equals extends EsqlBinaryComparison implements Negatable { +public class Equals extends EsqlBinaryComparison implements Negatable { private static final Map evaluatorMap = Map.ofEntries( Map.entry(DataTypes.BOOLEAN, EqualsBoolsEvaluator.Factory::new), Map.entry(DataTypes.INTEGER, EqualsIntsEvaluator.Factory::new), @@ -41,11 +39,11 @@ public class Equals extends EsqlBinaryComparison implements Negatable evaluatorMap; + private final BinaryComparisonOperation functionType; + + @FunctionalInterface + public interface BinaryOperatorConstructor { + EsqlBinaryComparison apply(Source source, Expression lhs, Expression rhs); + } + + public enum BinaryComparisonOperation implements Writeable { + + EQ(0, "==", BinaryComparisonProcessor.BinaryComparisonOperation.EQ, Equals::new), + // id 1 reserved for NullEquals + NEQ(2, "!=", BinaryComparisonProcessor.BinaryComparisonOperation.NEQ, NotEquals::new), + GT(3, ">", BinaryComparisonProcessor.BinaryComparisonOperation.GT, GreaterThan::new), + GTE(4, ">=", BinaryComparisonProcessor.BinaryComparisonOperation.GTE, GreaterThanOrEqual::new), + LT(5, "<", BinaryComparisonProcessor.BinaryComparisonOperation.LT, LessThan::new), + LTE(6, "<=", BinaryComparisonProcessor.BinaryComparisonOperation.LTE, LessThanOrEqual::new); + + private final int id; + private final String symbol; + // Temporary mapping to the old enum, to satisfy the superclass constructor signature. + private final BinaryComparisonProcessor.BinaryComparisonOperation shim; + private final BinaryOperatorConstructor constructor; + + BinaryComparisonOperation( + int id, + String symbol, + BinaryComparisonProcessor.BinaryComparisonOperation shim, + BinaryOperatorConstructor constructor + ) { + this.id = id; + this.symbol = symbol; + this.shim = shim; + this.constructor = constructor; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(id); + } + + public static BinaryComparisonOperation readFromStream(StreamInput in) throws IOException { + int id = in.readVInt(); + for (BinaryComparisonOperation op : values()) { + if (op.id == id) { + return op; + } + } + throw new IOException("No BinaryComparisonOperation found for id [" + id + "]"); + } + + public EsqlBinaryComparison buildNewInstance(Source source, Expression lhs, Expression rhs) { + return constructor.apply(source, lhs, rhs); + } + } + protected EsqlBinaryComparison( Source source, Expression left, Expression right, - /* TODO: BinaryComparisonOperator is an enum with a bunch of functionality we don't really want. We should extract an interface and - create a symbol only version like we did for BinaryArithmeticOperation. Ideally, they could be the same class. - */ - BinaryComparisonProcessor.BinaryComparisonOperation operation, + BinaryComparisonOperation operation, Map evaluatorMap ) { this(source, left, right, operation, null, evaluatorMap); @@ -49,13 +105,18 @@ protected EsqlBinaryComparison( Source source, Expression left, Expression right, - BinaryComparisonProcessor.BinaryComparisonOperation operation, + BinaryComparisonOperation operation, // TODO: We are definitely not doing the right thing with this zoneId ZoneId zoneId, Map evaluatorMap ) { - super(source, left, right, operation, zoneId); + super(source, left, right, operation.shim, zoneId); this.evaluatorMap = evaluatorMap; + this.functionType = operation; + } + + public BinaryComparisonOperation getFunctionType() { + return functionType; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/GreaterThan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/GreaterThan.java index 3eca0e858acbf..09fb32add0f18 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/GreaterThan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/GreaterThan.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.predicate.Negatable; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; @@ -21,7 +19,7 @@ import java.time.ZoneId; import java.util.Map; -public class GreaterThan extends EsqlBinaryComparison implements Negatable { +public class GreaterThan extends EsqlBinaryComparison implements Negatable { private static final Map evaluatorMap = Map.ofEntries( Map.entry(DataTypes.INTEGER, GreaterThanIntsEvaluator.Factory::new), Map.entry(DataTypes.DOUBLE, GreaterThanDoublesEvaluator.Factory::new), @@ -35,11 +33,11 @@ public class GreaterThan extends EsqlBinaryComparison implements Negatable { +public class GreaterThanOrEqual extends EsqlBinaryComparison implements Negatable { private static final Map evaluatorMap = Map.ofEntries( Map.entry(DataTypes.INTEGER, GreaterThanOrEqualIntsEvaluator.Factory::new), Map.entry(DataTypes.DOUBLE, GreaterThanOrEqualDoublesEvaluator.Factory::new), @@ -35,11 +33,11 @@ public class GreaterThanOrEqual extends EsqlBinaryComparison implements Negatabl ); public GreaterThanOrEqual(Source source, Expression left, Expression right) { - super(source, left, right, BinaryComparisonProcessor.BinaryComparisonOperation.GTE, evaluatorMap); + super(source, left, right, BinaryComparisonOperation.GTE, evaluatorMap); } public GreaterThanOrEqual(Source source, Expression left, Expression right, ZoneId zoneId) { - super(source, left, right, BinaryComparisonProcessor.BinaryComparisonOperation.GTE, zoneId, evaluatorMap); + super(source, left, right, BinaryComparisonOperation.GTE, zoneId, evaluatorMap); } @Override @@ -63,7 +61,7 @@ public LessThan negate() { } @Override - public BinaryComparison reverse() { + public EsqlBinaryComparison reverse() { return new LessThanOrEqual(source(), left(), right(), zoneId()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/LessThan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/LessThan.java index 6b82df1d67da6..1649706a643c3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/LessThan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/LessThan.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.predicate.Negatable; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; @@ -21,7 +19,7 @@ import java.time.ZoneId; import java.util.Map; -public class LessThan extends EsqlBinaryComparison implements Negatable { +public class LessThan extends EsqlBinaryComparison implements Negatable { private static final Map evaluatorMap = Map.ofEntries( Map.entry(DataTypes.INTEGER, LessThanIntsEvaluator.Factory::new), @@ -35,8 +33,12 @@ public class LessThan extends EsqlBinaryComparison implements Negatable { +public class LessThanOrEqual extends EsqlBinaryComparison implements Negatable { private static final Map evaluatorMap = Map.ofEntries( Map.entry(DataTypes.INTEGER, LessThanOrEqualIntsEvaluator.Factory::new), Map.entry(DataTypes.DOUBLE, LessThanOrEqualDoublesEvaluator.Factory::new), @@ -34,8 +32,12 @@ public class LessThanOrEqual extends EsqlBinaryComparison implements Negatable { +public class NotEquals extends EsqlBinaryComparison implements Negatable { private static final Map evaluatorMap = Map.ofEntries( Map.entry(DataTypes.BOOLEAN, NotEqualsBoolsEvaluator.Factory::new), Map.entry(DataTypes.INTEGER, NotEqualsIntsEvaluator.Factory::new), @@ -41,11 +39,11 @@ public class NotEquals extends EsqlBinaryComparison implements Negatable namedTypeEntries() { // NamedExpressions of(NamedExpression.class, Alias.class, PlanNamedTypes::writeAlias, PlanNamedTypes::readAlias), // BinaryComparison - of(BinaryComparison.class, Equals.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), - of(BinaryComparison.class, NullEquals.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), - of(BinaryComparison.class, NotEquals.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), - of(BinaryComparison.class, GreaterThan.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), - of(BinaryComparison.class, GreaterThanOrEqual.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), - of(BinaryComparison.class, LessThan.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), - of(BinaryComparison.class, LessThanOrEqual.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), + of(EsqlBinaryComparison.class, Equals.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), + of(EsqlBinaryComparison.class, NotEquals.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), + of(EsqlBinaryComparison.class, GreaterThan.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), + of(EsqlBinaryComparison.class, GreaterThanOrEqual.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), + of(EsqlBinaryComparison.class, LessThan.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), + of(EsqlBinaryComparison.class, LessThanOrEqual.class, PlanNamedTypes::writeBinComparison, PlanNamedTypes::readBinComparison), // InsensitiveEquals of( InsensitiveEquals.class, @@ -1199,26 +1196,19 @@ static void writeUnsupportedEsField(PlanStreamOutput out, UnsupportedEsField uns // -- BinaryComparison - static BinaryComparison readBinComparison(PlanStreamInput in, String name) throws IOException { + static EsqlBinaryComparison readBinComparison(PlanStreamInput in, String name) throws IOException { var source = in.readSource(); - var operation = in.readEnum(BinaryComparisonProcessor.BinaryComparisonOperation.class); + EsqlBinaryComparison.BinaryComparisonOperation operation = EsqlBinaryComparison.BinaryComparisonOperation.readFromStream(in); var left = in.readExpression(); var right = in.readExpression(); + // TODO: Remove zoneId entirely var zoneId = in.readOptionalZoneId(); - return switch (operation) { - case EQ -> new Equals(source, left, right, zoneId); - case NULLEQ -> new NullEquals(source, left, right, zoneId); - case NEQ -> new NotEquals(source, left, right, zoneId); - case GT -> new GreaterThan(source, left, right, zoneId); - case GTE -> new GreaterThanOrEqual(source, left, right, zoneId); - case LT -> new LessThan(source, left, right, zoneId); - case LTE -> new LessThanOrEqual(source, left, right, zoneId); - }; - } - - static void writeBinComparison(PlanStreamOutput out, BinaryComparison binaryComparison) throws IOException { + return operation.buildNewInstance(source, left, right); + } + + static void writeBinComparison(PlanStreamOutput out, EsqlBinaryComparison binaryComparison) throws IOException { out.writeSource(binaryComparison.source()); - out.writeEnum(binaryComparison.function()); + binaryComparison.getFunctionType().writeTo(out); out.writeExpression(binaryComparison.left()); out.writeExpression(binaryComparison.right()); out.writeOptionalZoneId(binaryComparison.zoneId()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/EsqlBinaryComparisonTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/EsqlBinaryComparisonTests.java new file mode 100644 index 0000000000000..5e9e702ff8d12 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/EsqlBinaryComparisonTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor; + +import java.io.IOException; +import java.util.List; + +public class EsqlBinaryComparisonTests extends ESTestCase { + + public void testSerializationOfBinaryComparisonOperation() throws IOException { + for (BinaryComparisonOperation op : BinaryComparisonOperation.values()) { + BinaryComparisonOperation newOp = copyWriteable( + op, + new NamedWriteableRegistry(List.of()), + BinaryComparisonOperation::readFromStream + ); + assertEquals(op, newOp); + } + } + + /** + * Test that a serialized + * {@link org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation} + * can be read back as a + * {@link BinaryComparisonOperation} + */ + public void testCompatibleWithQLBinaryComparisonOperation() throws IOException { + validateCompatibility(BinaryComparisonProcessor.BinaryComparisonOperation.EQ, BinaryComparisonOperation.EQ); + validateCompatibility(BinaryComparisonProcessor.BinaryComparisonOperation.NEQ, BinaryComparisonOperation.NEQ); + validateCompatibility(BinaryComparisonProcessor.BinaryComparisonOperation.GT, BinaryComparisonOperation.GT); + validateCompatibility(BinaryComparisonProcessor.BinaryComparisonOperation.GTE, BinaryComparisonOperation.GTE); + validateCompatibility(BinaryComparisonProcessor.BinaryComparisonOperation.LT, BinaryComparisonOperation.LT); + validateCompatibility(BinaryComparisonProcessor.BinaryComparisonOperation.LTE, BinaryComparisonOperation.LTE); + } + + private static void validateCompatibility( + BinaryComparisonProcessor.BinaryComparisonOperation original, + BinaryComparisonOperation expected + ) throws IOException { + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setTransportVersion(TransportVersion.current()); + output.writeEnum(original); + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), new NamedWriteableRegistry(List.of()))) { + in.setTransportVersion(TransportVersion.current()); + BinaryComparisonOperation newOp = BinaryComparisonOperation.readFromStream(in); + assertEquals(expected, newOp); + } + } + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java index 57d86147a5bba..e22fa3c66384b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.SerializationTestUtils; import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.LessThan; @@ -45,7 +46,6 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NullEquals; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -86,7 +86,6 @@ import org.elasticsearch.xpack.ql.expression.function.Function; import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.ArithmeticOperation; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.index.EsIndex; import org.elasticsearch.xpack.ql.options.EsSourceOptions; import org.elasticsearch.xpack.ql.plan.logical.Filter; @@ -103,10 +102,8 @@ import org.elasticsearch.xpack.ql.type.KeywordEsField; import org.elasticsearch.xpack.ql.type.TextEsField; import org.elasticsearch.xpack.ql.type.UnsupportedEsField; -import org.elasticsearch.xpack.ql.util.DateUtils; import java.io.IOException; -import java.time.ZoneId; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -334,15 +331,15 @@ public void testBinComparisonSimple() throws IOException { var orig = new Equals(Source.EMPTY, field("foo", DataTypes.DOUBLE), field("bar", DataTypes.DOUBLE)); BytesStreamOutput bso = new BytesStreamOutput(); PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry); - out.writeNamed(BinaryComparison.class, orig); - var deser = (Equals) planStreamInput(bso).readNamed(BinaryComparison.class); + out.writeNamed(EsqlBinaryComparison.class, orig); + var deser = (Equals) planStreamInput(bso).readNamed(EsqlBinaryComparison.class); EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); } public void testBinComparison() { Stream.generate(PlanNamedTypesTests::randomBinaryComparison) .limit(100) - .forEach(obj -> assertNamedType(BinaryComparison.class, obj)); + .forEach(obj -> assertNamedType(EsqlBinaryComparison.class, obj)); } public void testAggFunctionSimple() throws IOException { @@ -582,18 +579,17 @@ static InvalidMappedField randomInvalidMappedField() { ); } - static BinaryComparison randomBinaryComparison() { - int v = randomIntBetween(0, 6); + static EsqlBinaryComparison randomBinaryComparison() { + int v = randomIntBetween(0, 5); var left = field(randomName(), randomDataType()); var right = field(randomName(), randomDataType()); return switch (v) { - case 0 -> new Equals(Source.EMPTY, left, right, zoneIdOrNull()); - case 1 -> new NullEquals(Source.EMPTY, left, right, zoneIdOrNull()); - case 2 -> new NotEquals(Source.EMPTY, left, right, zoneIdOrNull()); - case 3 -> new GreaterThan(Source.EMPTY, left, right, zoneIdOrNull()); - case 4 -> new GreaterThanOrEqual(Source.EMPTY, left, right, zoneIdOrNull()); - case 5 -> new LessThan(Source.EMPTY, left, right, zoneIdOrNull()); - case 6 -> new LessThanOrEqual(Source.EMPTY, left, right, zoneIdOrNull()); + case 0 -> new Equals(Source.EMPTY, left, right); + case 1 -> new NotEquals(Source.EMPTY, left, right); + case 2 -> new GreaterThan(Source.EMPTY, left, right); + case 3 -> new GreaterThanOrEqual(Source.EMPTY, left, right); + case 4 -> new LessThan(Source.EMPTY, left, right); + case 5 -> new LessThanOrEqual(Source.EMPTY, left, right); default -> throw new AssertionError(v); }; } @@ -635,10 +631,6 @@ static NameId nameIdOrNull() { return randomBoolean() ? new NameId() : null; } - static ZoneId zoneIdOrNull() { - return randomBoolean() ? DateUtils.UTC : null; - } - static Nullability randomNullability() { int i = randomInt(2); return switch (i) {