Skip to content

Commit

Permalink
JAVA-3060: Add vector type, codec + support for parsing CQL type (#1639)
Browse files Browse the repository at this point in the history
  • Loading branch information
absurdfarce committed Jun 5, 2023
1 parent f91979f commit cfeb55f
Show file tree
Hide file tree
Showing 17 changed files with 593 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.api.core.data;

import com.datastax.oss.driver.shaded.guava.common.base.Joiner;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterators;
import java.util.Arrays;

/**
 * An n-dimensional vector defined in CQL.
 *
 * <p>Instances are immutable: the elements are held in an {@link ImmutableList} and can only be
 * supplied through a {@link Builder} obtained from {@link #builder()}.
 */
public class CqlVector<T> {

  /** The elements of this vector. */
  private final ImmutableList<T> values;

  private CqlVector(ImmutableList<T> values) {
    this.values = values;
  }

  /** Creates a new, empty builder. */
  public static Builder builder() {
    return new Builder();
  }

  /** Returns the elements of this vector. */
  public Iterable<T> getValues() {
    return values;
  }

  /** Two vectors are equal when their element lists are equal. */
  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (!(o instanceof CqlVector)) {
      return false;
    }
    CqlVector<?> other = (CqlVector<?>) o;
    return this.values.equals(other.values);
  }

  @Override
  public int hashCode() {
    Object[] elements = values.toArray();
    return Arrays.hashCode(elements);
  }

  /** Renders the elements separated by {@code ", "} and wrapped in {@code CqlVector{...}}. */
  @Override
  public String toString() {
    return "CqlVector{" + Joiner.on(", ").join(this.values) + '}';
  }

  /** Accumulates elements and produces a {@link CqlVector} from them. */
  public static class Builder<T> {

    private final ImmutableList.Builder<T> listBuilder;

    private Builder() {
      this.listBuilder = new ImmutableList.Builder<>();
    }

    /** Appends a single element. */
    public Builder add(T element) {
      listBuilder.add(element);
      return this;
    }

    /** Appends every element of the given array, in array order. */
    public Builder add(T... elements) {
      listBuilder.addAll(Iterators.forArray(elements));
      return this;
    }

    /** Appends every element produced by the given iterable. */
    public Builder addAll(Iterable<T> iter) {
      listBuilder.addAll(iter);
      return this;
    }

    /** Creates the vector from the elements added so far. */
    public CqlVector<T> build() {
      ImmutableList<T> elements = listBuilder.build();
      return new CqlVector<>(elements);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,24 @@ default CqlDuration getCqlDuration(@NonNull CqlIdentifier id) {
return getCqlDuration(firstIndexOf(id));
}

/**
 * Looks up the first occurrence of {@code id} and returns its value as a vector.
 *
 * <p>By default, this works with CQL type {@code vector}.
 *
 * <p>When the same identifier appears more than once, only the first value is reachable through
 * this method; use the positional getters for the others.
 *
 * <p>To avoid the overhead of building a {@code CqlIdentifier}, use the variant of this method
 * that takes a string argument.
 *
 * @throws IllegalArgumentException if the id is invalid.
 */
@Nullable
default CqlVector<?> getCqlVector(@NonNull CqlIdentifier id) {
  int firstIndex = firstIndexOf(id);
  return getCqlVector(firstIndex);
}

/**
* Returns the value for the first occurrence of {@code id} as a token.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,18 @@ default CqlDuration getCqlDuration(int i) {
return get(i, CqlDuration.class);
}

/**
 * Reads the value at position {@code i} as a vector.
 *
 * <p>By default, this works with CQL type {@code vector}.
 *
 * @throws IndexOutOfBoundsException if the index is invalid.
 */
@Nullable
default CqlVector<?> getCqlVector(int i) {
  CqlVector<?> vector = get(i, CqlVector.class);
  return vector;
}

/**
* Returns the {@code i}th value as a token.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,24 @@ default CqlDuration getCqlDuration(@NonNull String name) {
return getCqlDuration(firstIndexOf(name));
}

/**
 * Looks up the first occurrence of {@code name} and returns its value as a vector.
 *
 * <p>By default, this works with CQL type {@code vector}.
 *
 * <p>When the same name appears more than once, only the first value is reachable through this
 * method; use the positional getters for the others.
 *
 * <p>Case sensitivity is handled in the way explained in the documentation of {@link
 * AccessibleByName}.
 *
 * @throws IllegalArgumentException if the name is invalid.
 */
@Nullable
default CqlVector<?> getCqlVector(@NonNull String name) {
  int firstIndex = firstIndexOf(name);
  return getCqlVector(firstIndex);
}

/**
* Returns the value for the first occurrence of {@code name} as a token.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,27 @@ default SelfT setCqlDuration(@NonNull CqlIdentifier id, @Nullable CqlDuration v)
return result;
}

/**
 * Sets the value for all occurrences of {@code id} to the provided vector.
 *
 * <p>By default, this works with CQL type {@code vector}.
 *
 * <p>To avoid the overhead of building a {@code CqlIdentifier}, use the variant of this method
 * that takes a string argument.
 *
 * @throws IllegalArgumentException if the id is invalid.
 */
@NonNull
@CheckReturnValue
default SelfT setCqlVector(@NonNull CqlIdentifier id, @Nullable CqlVector<?> v) {
  SelfT updated = null;
  for (Integer index : allIndicesOf(id)) {
    // The first write goes through this instance, later ones through the previous result.
    if (updated == null) {
      updated = this.setCqlVector(index, v);
    } else {
      updated = updated.setCqlVector(index, v);
    }
  }
  assert updated != null; // allIndices throws if there are no results
  return updated;
}

/**
* Sets the value for all occurrences of {@code id} to the provided token.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,19 @@ default SelfT setCqlDuration(int i, @Nullable CqlDuration v) {
return set(i, v, CqlDuration.class);
}

/**
 * Sets the {@code i}th value to the provided vector.
 *
 * <p>By default, this works with CQL type {@code vector}.
 *
 * @throws IndexOutOfBoundsException if the index is invalid.
 */
@NonNull
@CheckReturnValue
default SelfT setCqlVector(int i, @Nullable CqlVector<?> v) {
  SelfT updated = set(i, v, CqlVector.class);
  return updated;
}

/**
* Sets the {@code i}th value to the provided token.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,27 @@ default SelfT setCqlDuration(@NonNull String name, @Nullable CqlDuration v) {
return result;
}

/**
 * Sets the value for all occurrences of {@code name} to the provided vector.
 *
 * <p>By default, this works with CQL type {@code vector}.
 *
 * <p>Case sensitivity is handled in the way explained in the documentation of {@link
 * AccessibleByName}.
 *
 * @throws IllegalArgumentException if the name is invalid.
 */
@NonNull
@CheckReturnValue
default SelfT setCqlVector(@NonNull String name, @Nullable CqlVector<?> v) {
  SelfT updated = null;
  for (Integer index : allIndicesOf(name)) {
    // The first write goes through this instance, later ones through the previous result.
    if (updated == null) {
      updated = this.setCqlVector(index, v);
    } else {
      updated = updated.setCqlVector(index, v);
    }
  }
  assert updated != null; // allIndices throws if there are no results
  return updated;
}

/**
* Sets the value for all occurrences of {@code name} to the provided token.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.api.core.type;

import com.datastax.oss.driver.api.core.detach.AttachmentPoint;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.Objects;

/**
 * A {@link CustomType} describing a CQL vector: a subtype for the elements plus a number of
 * dimensions.
 *
 * <p>Equality and hash code are based on the subtype and the dimensions only.
 */
public class CqlVectorType implements CustomType {

  /** Class name reported by {@link #getClassName()} for every vector type. */
  public static final String CQLVECTOR_CLASS_NAME = "org.apache.cassandra.db.marshal.VectorType";

  private final DataType subtype;
  private final int dimensions;

  /**
   * @param subtype the type of the vector's elements.
   * @param dimensions the number of dimensions of the vector.
   */
  public CqlVectorType(DataType subtype, int dimensions) {

    this.dimensions = dimensions;
    this.subtype = subtype;
  }

  /** Returns the number of dimensions of this vector type. */
  public int getDimensions() {
    return this.dimensions;
  }

  /** Returns the type of the elements of this vector type. */
  public DataType getSubtype() {
    return this.subtype;
  }

  @NonNull
  @Override
  public String getClassName() {
    return CQLVECTOR_CLASS_NAME;
  }

  /**
   * Renders this type as the quoted class name followed by the dimensions in parentheses. Both
   * arguments are ignored.
   */
  @NonNull
  @Override
  public String asCql(boolean includeFrozen, boolean pretty) {
    // NOTE(review): the subtype is not part of the rendered string -- confirm this is intended.
    return String.format("'%s(%d)'", getClassName(), getDimensions());
  }

  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    } else if (o instanceof CqlVectorType) {
      CqlVectorType that = (CqlVectorType) o;
      // Objects.equals keeps the comparison safe if an instance was built with a null subtype.
      return that.dimensions == this.dimensions && Objects.equals(that.subtype, this.subtype);
    } else {
      return false;
    }
  }

  @Override
  public int hashCode() {
    // Must only depend on the fields compared in equals(): mixing in super.hashCode() (the
    // identity hash) would give equal instances different hash codes.
    return Objects.hash(subtype, dimensions);
  }

  @Override
  public String toString() {
    return String.format("CqlVector(%s, %d)", getSubtype(), getDimensions());
  }

  /** Always returns {@code false}. */
  @Override
  public boolean isDetached() {
    return false;
  }

  @Override
  public void attach(@NonNull AttachmentPoint attachmentPoint) {
    // nothing to do
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,21 @@
*/
package com.datastax.oss.driver.api.core.type;

import com.datastax.oss.driver.api.core.detach.AttachmentPoint;
import com.datastax.oss.driver.api.core.detach.Detachable;
import com.datastax.oss.driver.internal.core.metadata.schema.parsing.DataTypeClassNameParser;
import com.datastax.oss.driver.internal.core.type.DefaultCustomType;
import com.datastax.oss.driver.internal.core.type.DefaultListType;
import com.datastax.oss.driver.internal.core.type.DefaultMapType;
import com.datastax.oss.driver.internal.core.type.DefaultSetType;
import com.datastax.oss.driver.internal.core.type.DefaultTupleType;
import com.datastax.oss.driver.internal.core.type.PrimitiveType;
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.Arrays;
import java.util.List;

/** Constants and factory methods to obtain data type instances. */
public class DataTypes {
Expand All @@ -51,14 +55,26 @@ public class DataTypes {
public static final DataType TINYINT = new PrimitiveType(ProtocolConstants.DataType.TINYINT);
public static final DataType DURATION = new PrimitiveType(ProtocolConstants.DataType.DURATION);

// Shared parser that turns a type class name into a DataType.
private static final DataTypeClassNameParser classNameParser = new DataTypeClassNameParser();
// Splits a comma-separated parameter list, trimming each resulting entry.
private static final Splitter paramSplitter = Splitter.on(',').trimResults();

@NonNull
public static DataType custom(@NonNull String className) {

// In protocol v4, duration is implemented as a custom type
if ("org.apache.cassandra.db.marshal.DurationType".equals(className)) {
return DURATION;
} else {
return new DefaultCustomType(className);
if (className.equals("org.apache.cassandra.db.marshal.DurationType")) return DURATION;

/* Vector support is currently implemented as a custom type but is also parameterized */
if (className.startsWith(CqlVectorType.CQLVECTOR_CLASS_NAME)) {
List<String> params =
paramSplitter.splitToList(
className.substring(
CqlVectorType.CQLVECTOR_CLASS_NAME.length() + 1, className.length() - 1));
DataType subType = classNameParser.parse(params.get(0), AttachmentPoint.NONE);
int dimensions = Integer.parseInt(params.get(1));
return new CqlVectorType(subType, dimensions);
}
return new DefaultCustomType(className);
}

@NonNull
Expand Down Expand Up @@ -118,4 +134,8 @@ public static MapType frozenMapOf(@NonNull DataType keyType, @NonNull DataType v
public static TupleType tupleOf(@NonNull DataType... componentTypes) {
  // Snapshot the varargs array into an immutable list before handing it to the tuple type.
  ImmutableList<DataType> components = ImmutableList.copyOf(Arrays.asList(componentTypes));
  return new DefaultTupleType(components);
}

/**
 * Creates a vector type with the given element type and number of dimensions.
 *
 * @param subtype the type of the vector's elements.
 * @param dimensions the number of dimensions of the vector.
 */
public static CqlVectorType vectorOf(DataType subtype, int dimensions) {
  CqlVectorType vectorType = new CqlVectorType(subtype, dimensions);
  return vectorType;
}
}

0 comments on commit cfeb55f

Please sign in to comment.