Skip to content

Commit

Permalink
Fixing lucene snapshot compilation due to new VectorValue interface (#…
Browse files Browse the repository at this point in the history
…108516)

* Fixing lucene snapshot compilation due to new VectorValue interface

* fixing test compilation
  • Loading branch information
benwtrent committed May 10, 2024
1 parent fee12c9 commit 92c48e5
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 0 deletions.
Expand Up @@ -39,6 +39,7 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
Expand Down Expand Up @@ -457,6 +458,8 @@ private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
docsWithField.cardinality(),
mergedQuantizationState, // TODO: bits
false, // TODO compress
fieldInfo.getVectorSimilarityFunction(),
vectorsScorer,
quantizationDataInput
)
);
Expand Down Expand Up @@ -649,6 +652,11 @@ public int advance(int target) throws IOException {
curDoc = target;
return docID();
}

@Override
public VectorScorer scorer(float[] floats) throws IOException {
throw new UnsupportedOperationException();
}
}

private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
Expand Down Expand Up @@ -770,6 +778,11 @@ public int dimension() {
public float getScoreCorrectionConstant() throws IOException {
return current.values.getScoreCorrectionConstant();
}

@Override
public VectorScorer vectorScorer(float[] floats) throws IOException {
throw new UnsupportedOperationException();
}
}

private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
Expand Down Expand Up @@ -836,6 +849,11 @@ public int advance(int target) throws IOException {
return doc;
}

@Override
public VectorScorer vectorScorer(float[] floats) throws IOException {
throw new UnsupportedOperationException();
}

private void quantize() throws IOException {
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
Expand Down Expand Up @@ -932,5 +950,10 @@ public int nextDoc() throws IOException {
public int advance(int target) throws IOException {
return in.advance(target);
}

@Override
public VectorScorer vectorScorer(float[] floats) throws IOException {
throw new UnsupportedOperationException();
}
}
}
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.VectorScorer;

import java.io.IOException;

Expand Down Expand Up @@ -63,6 +64,11 @@ public int advance(int target) throws IOException {
return in.advance(target);
}

@Override
public VectorScorer scorer(float[] floats) throws IOException {
return in.scorer(floats);
}

public float magnitude() {
return magnitude;
}
Expand Down
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.suggest.document.CompletionTerms;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
Expand Down Expand Up @@ -481,6 +482,27 @@ public byte[] vectorValue() throws IOException {
return in.vectorValue();
}

@Override
public VectorScorer scorer(byte[] bytes) throws IOException {
VectorScorer scorer = in.scorer(bytes);
if (scorer == null) {
return null;
}
return new VectorScorer() {
private final DocIdSetIterator iterator = new ExitableDocSetIterator(scorer.iterator(), queryCancellation);

@Override
public float score() throws IOException {
return scorer.score();
}

@Override
public DocIdSetIterator iterator() {
return iterator;
}
};
}

@Override
public int docID() {
return in.docID();
Expand Down Expand Up @@ -531,11 +553,72 @@ public int nextDoc() throws IOException {
return nextDoc;
}

@Override
public VectorScorer scorer(float[] target) throws IOException {
VectorScorer scorer = in.scorer(target);
if (scorer == null) {
return null;
}
return new VectorScorer() {
private final DocIdSetIterator iterator = new ExitableDocSetIterator(scorer.iterator(), queryCancellation);

@Override
public float score() throws IOException {
return scorer.score();
}

@Override
public DocIdSetIterator iterator() {
return iterator;
}
};
}

private void checkAndThrowWithSampling() {
if ((calls++ & ExitableIntersectVisitor.MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) {
this.queryCancellation.checkCancelled();
}
}
}

private static class ExitableDocSetIterator extends DocIdSetIterator {
private int calls;
private final DocIdSetIterator in;
private final QueryCancellation queryCancellation;

private ExitableDocSetIterator(DocIdSetIterator in, QueryCancellation queryCancellation) {
this.in = in;
this.queryCancellation = queryCancellation;
}

@Override
public int docID() {
return in.docID();
}

@Override
public int advance(int target) throws IOException {
final int advance = in.advance(target);
checkAndThrowWithSampling();
return advance;
}

@Override
public int nextDoc() throws IOException {
final int nextDoc = in.nextDoc();
checkAndThrowWithSampling();
return nextDoc;
}

@Override
public long cost() {
return in.cost();
}

private void checkAndThrowWithSampling() {
if ((calls++ & ExitableIntersectVisitor.MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) {
this.queryCancellation.checkCancelled();
}
}
}
}
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.VectorScorer;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.DenseVector;
Expand Down Expand Up @@ -230,6 +231,11 @@ public int advance(int target) {
}
return index = target;
}

@Override
public VectorScorer scorer(byte[] floats) throws IOException {
throw new UnsupportedOperationException();
}
};
}

Expand Down Expand Up @@ -270,6 +276,11 @@ public int advance(int target) {
}
return index = target;
}

@Override
public VectorScorer scorer(float[] floats) throws IOException {
throw new UnsupportedOperationException();
}
};
}
}

0 comments on commit 92c48e5

Please sign in to comment.