Skip to content

Commit

Permalink
adding vector find ops
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkWolters committed May 10, 2024
1 parent e6d2c35 commit 101e9d9
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 0 deletions.
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.nosqlbench.adapter.dataapi;

import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.Filter;
import com.datastax.astra.client.model.FindOptions;
import com.datastax.astra.client.model.Projection;
import com.datastax.astra.client.model.Sort;
import io.nosqlbench.adapter.dataapi.opdispensers.DataApiOpDispenser;
import io.nosqlbench.adapter.dataapi.ops.DataApiBaseOp;
import io.nosqlbench.adapter.dataapi.ops.DataApiFindVectorFilterOp;
import io.nosqlbench.adapters.api.templating.ParsedOp;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.function.LongFunction;

public class DataApiFindVectorFilterOpDispenser extends DataApiOpDispenser {
private static final Logger logger = LogManager.getLogger(DataApiFindVectorFilterOpDispenser.class);
private final LongFunction<DataApiFindVectorFilterOp> opFunction;
public DataApiFindVectorFilterOpDispenser(DataApiDriverAdapter adapter, ParsedOp op, LongFunction<String> targetFunction) {
super(adapter, op, targetFunction);
this.opFunction = createOpFunction(op);
}

private LongFunction<DataApiFindVectorFilterOp> createOpFunction(ParsedOp op) {
return (l) -> {
Database db = spaceFunction.apply(l).getDatabase();
float[] vector = getVectorValues(op, l);
Filter filter = getFilterFromOp(op, l);
int limit = getLimit(op, l);
return new DataApiFindVectorFilterOp(
db,
db.getCollection(targetFunction.apply(l)),
vector,
limit,
filter
);
};
}

private int getLimit(ParsedOp op, long l) {
return op.getConfigOr("limit", 100, l);
}

private FindOptions getFindOptions(ParsedOp op, long l) {
FindOptions options = new FindOptions();
Sort sort = getSortFromOp(op, l);
float[] vector = getVectorValues(op, l);
if (sort != null) {
options = vector != null ? options.sort(vector, sort) : options.sort(sort);
} else if (vector != null) {
options = options.sort(vector);
}
Projection[] projection = getProjectionFromOp(op, l);
if (projection != null) {
options = options.projection(projection);
}
options.setIncludeSimilarity(true);
return options;
}

@Override
public DataApiBaseOp getOp(long value) {
return opFunction.apply(value);
}
}
Expand Up @@ -47,10 +47,13 @@ public OpDispenser<? extends DataApiBaseOp> apply(ParsedOp op) {
case create_collection -> new DataApiCreateCollectionOpDispenser(adapter, op, typeAndTarget.targetFunction);
case insert_many -> new DataApiInsertManyOpDispenser(adapter, op, typeAndTarget.targetFunction);
case insert_one -> new DataApiInsertOneOpDispenser(adapter, op, typeAndTarget.targetFunction);
case insert_one_vector -> new DataApiInsertOneVectorOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find -> new DataApiFindOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_one -> new DataApiFindOneOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_one_and_delete -> new DataApiFindOneAndDeleteOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_one_and_update -> new DataApiFindOneAndUpdateOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_vector -> new DataApiFindVectorOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_vector_filter -> new DataApiFindVectorFilterOpDispenser(adapter, op, typeAndTarget.targetFunction);
case update_one -> new DataApiUpdateOneOpDispenser(adapter, op, typeAndTarget.targetFunction);
case update_many -> new DataApiUpdateManyOpDispenser(adapter, op, typeAndTarget.targetFunction);
case delete_one -> new DataApiDeleteOneOpDispenser(adapter, op, typeAndTarget.targetFunction);
Expand Down
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.nosqlbench.adapter.dataapi.opdispensers;

import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.FindOptions;
import com.datastax.astra.client.model.Projection;
import com.datastax.astra.client.model.Sort;
import io.nosqlbench.adapter.dataapi.DataApiDriverAdapter;
import io.nosqlbench.adapter.dataapi.ops.DataApiBaseOp;
import io.nosqlbench.adapter.dataapi.ops.DataApiFindVectorOp;
import io.nosqlbench.adapters.api.templating.ParsedOp;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.function.LongFunction;

public class DataApiFindVectorOpDispenser extends DataApiOpDispenser {
private static final Logger logger = LogManager.getLogger(DataApiFindVectorOpDispenser.class);
private final LongFunction<DataApiFindVectorOp> opFunction;
public DataApiFindVectorOpDispenser(DataApiDriverAdapter adapter, ParsedOp op, LongFunction<String> targetFunction) {
super(adapter, op, targetFunction);
this.opFunction = createOpFunction(op);
}

private LongFunction<DataApiFindVectorOp> createOpFunction(ParsedOp op) {
return (l) -> {
Database db = spaceFunction.apply(l).getDatabase();
float[] vector = getVectorValues(op, l);
int limit = getLimit(op, l);
return new DataApiFindVectorOp(
db,
db.getCollection(targetFunction.apply(l)),
vector,
limit
);
};
}

private int getLimit(ParsedOp op, long l) {
return op.getConfigOr("limit", 100, l);
}

private FindOptions getFindOptions(ParsedOp op, long l) {
FindOptions options = new FindOptions();
Sort sort = getSortFromOp(op, l);
float[] vector = getVectorValues(op, l);
if (sort != null) {
options = vector != null ? options.sort(vector, sort) : options.sort(sort);
} else if (vector != null) {
options = options.sort(vector);
}
Projection[] projection = getProjectionFromOp(op, l);
if (projection != null) {
options = options.projection(projection);
}
options.setIncludeSimilarity(true);
return options;
}

@Override
public DataApiBaseOp getOp(long value) {
return opFunction.apply(value);
}
}
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.nosqlbench.adapter.dataapi.ops;

import com.datastax.astra.client.Collection;
import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.Document;
import com.datastax.astra.client.model.Filter;

public class DataApiFindVectorFilterOp extends DataApiBaseOp {
private final Collection<Document> collection;
private final float[] vector;
private final int limit;
private final Filter filter;

public DataApiFindVectorFilterOp(Database db, Collection<Document> collection, float[] vector, int limit, Filter filter) {
super(db);
this.collection = collection;
this.vector = vector;
this.limit = limit;
this.filter = filter;
}

@Override
public Object apply(long value) {
return collection.find(filter, vector, limit);
}
}
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.nosqlbench.adapter.dataapi.ops;

import com.datastax.astra.client.Collection;
import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.Document;
import com.datastax.astra.client.model.Filter;
import com.datastax.astra.client.model.FindOptions;

public class DataApiFindVectorOp extends DataApiBaseOp {
private final Collection<Document> collection;
private final float[] vector;
private final int limit;

public DataApiFindVectorOp(Database db, Collection<Document> collection, float[] vector, int limit) {
super(db);
this.collection = collection;
this.vector = vector;
this.limit = limit;
}

@Override
public Object apply(long value) {
return collection.find(vector, limit);
}
}
Expand Up @@ -20,10 +20,13 @@ public enum DataApiOpType {
create_collection,
insert_many,
insert_one,
insert_one_vector,
find,
find_one,
find_one_and_delete,
find_one_and_update,
find_vector,
find_vector_filter,
update_one,
update_many,
delete_one,
Expand Down

0 comments on commit 101e9d9

Please sign in to comment.