Skip to content

Commit

Permalink
Address compile errors after vector api changes upstream (#113766)
Browse files Browse the repository at this point in the history
Our lucene_snapshot branch requires updating after apache/lucene#13779
  • Loading branch information
javanna committed Sep 30, 2024
1 parent b6532e6 commit 63e524d
Show file tree
Hide file tree
Showing 22 changed files with 285 additions and 230 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.core.IOUtils;
Expand Down Expand Up @@ -217,19 +217,17 @@ public float squareDistanceScalar() {
return 1 / (1f + adjustedDistance);
}

RandomAccessQuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
QuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
var sq = new ScalarQuantizer(0.1f, 0.9f, (byte) 7);
var slice = in.slice("values", 0, in.length());
return new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(dims, size, sq, false, sim, null, slice);
}

RandomVectorScorerSupplier luceneScoreSupplier(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim)
throws IOException {
RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) throws IOException {
return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values);
}

RandomVectorScorer luceneScorer(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVec)
throws IOException {
RandomVectorScorer luceneScorer(QuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVec) throws IOException {
return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorer(sim, values, queryVec);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;

import java.util.Optional;

Expand All @@ -39,7 +39,7 @@ static Optional<VectorScorerFactory> instance() {
Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float scoreCorrectionConstant
);

Expand All @@ -52,9 +52,5 @@ Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
* @param queryVector the query vector
* @return an optional containing the vector scorer, or empty
*/
Optional<RandomVectorScorer> getInt7SQVectorScorer(
VectorSimilarityFunction sim,
RandomAccessQuantizedByteVectorValues values,
float[] queryVector
);
Optional<RandomVectorScorer> getInt7SQVectorScorer(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;

import java.util.Optional;

Expand All @@ -25,7 +25,7 @@ final class VectorScorerFactoryImpl implements VectorScorerFactory {
public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float scoreCorrectionConstant
) {
throw new UnsupportedOperationException("should not reach here");
Expand All @@ -34,7 +34,7 @@ public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
@Override
public Optional<RandomVectorScorer> getInt7SQVectorScorer(
VectorSimilarityFunction sim,
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float[] queryVector
) {
throw new UnsupportedOperationException("should not reach here");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.simdvec.internal.Int7SQVectorScorer;
import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.DotProductSupplier;
Expand All @@ -38,7 +38,7 @@ private VectorScorerFactoryImpl() {}
public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float scoreCorrectionConstant
) {
input = FilterIndexInput.unwrapOnlyTest(input);
Expand All @@ -57,7 +57,7 @@ public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
@Override
public Optional<RandomVectorScorer> getInt7SQVectorScorer(
VectorSimilarityFunction sim,
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float[] queryVector
) {
return Int7SQVectorScorer.create(sim, values, queryVector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,14 @@

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;

import java.util.Optional;

public final class Int7SQVectorScorer {

// Unconditionally returns an empty optional on <= JDK 21, since the scorer is only supported on JDK 22+
public static Optional<RandomVectorScorer> create(
VectorSimilarityFunction sim,
RandomAccessQuantizedByteVectorValues values,
float[] queryVector
) {
public static Optional<RandomVectorScorer> create(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector) {
return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;

import java.io.IOException;
Expand All @@ -31,12 +31,12 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
final int maxOrd;
final float scoreCorrectionConstant;
final MemorySegmentAccessInput input;
final RandomAccessQuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds
final QuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds
final ScalarQuantizedVectorSimilarity fallbackScorer;

protected Int7SQVectorScorerSupplier(
MemorySegmentAccessInput input,
RandomAccessQuantizedByteVectorValues values,
QuantizedByteVectorValues values,
float scoreCorrectionConstant,
ScalarQuantizedVectorSimilarity fallbackScorer
) {
Expand Down Expand Up @@ -104,11 +104,7 @@ public float score(int node) throws IOException {

public static final class EuclideanSupplier extends Int7SQVectorScorerSupplier {

public EuclideanSupplier(
MemorySegmentAccessInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant
) {
public EuclideanSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
super(input, values, scoreCorrectionConstant, fromVectorSimilarity(EUCLIDEAN, scoreCorrectionConstant, BITS));
}

Expand All @@ -127,11 +123,7 @@ public EuclideanSupplier copy() {

public static final class DotProductSupplier extends Int7SQVectorScorerSupplier {

public DotProductSupplier(
MemorySegmentAccessInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant
) {
public DotProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
super(input, values, scoreCorrectionConstant, fromVectorSimilarity(DOT_PRODUCT, scoreCorrectionConstant, BITS));
}

Expand All @@ -151,11 +143,7 @@ public DotProductSupplier copy() {

public static final class MaxInnerProductSupplier extends Int7SQVectorScorerSupplier {

public MaxInnerProductSupplier(
MemorySegmentAccessInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant
) {
public MaxInnerProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) {
super(input, values, scoreCorrectionConstant, fromVectorSimilarity(MAXIMUM_INNER_PRODUCT, scoreCorrectionConstant, BITS));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

import java.io.IOException;
Expand All @@ -35,11 +35,7 @@ public abstract sealed class Int7SQVectorScorer extends RandomVectorScorer.Abstr
byte[] scratch;

/** Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is returned. */
public static Optional<RandomVectorScorer> create(
VectorSimilarityFunction sim,
RandomAccessQuantizedByteVectorValues values,
float[] queryVector
) {
public static Optional<RandomVectorScorer> create(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector) {
checkDimensions(queryVector.length, values.dimension());
var input = values.getSlice();
if (input == null) {
Expand All @@ -63,12 +59,7 @@ public static Optional<RandomVectorScorer> create(
};
}

Int7SQVectorScorer(
MemorySegmentAccessInput input,
RandomAccessQuantizedByteVectorValues values,
byte[] queryVector,
float queryCorrection
) {
Int7SQVectorScorer(MemorySegmentAccessInput input, QuantizedByteVectorValues values, byte[] queryVector, float queryCorrection) {
super(values);
this.input = input;
assert queryVector.length == values.getVectorByteLength();
Expand Down Expand Up @@ -105,7 +96,7 @@ final void checkOrdinal(int ord) {
}

public static final class DotProductScorer extends Int7SQVectorScorer {
public DotProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) {
public DotProductScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float correction) {
super(in, values, query, correction);
}

Expand All @@ -122,7 +113,7 @@ public float score(int node) throws IOException {
}

public static final class EuclideanScorer extends Int7SQVectorScorer {
public EuclideanScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) {
public EuclideanScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float correction) {
super(in, values, query, correction);
}

Expand All @@ -136,7 +127,7 @@ public float score(int node) throws IOException {
}

public static final class MaxInnerProductScorer extends Int7SQVectorScorer {
public MaxInnerProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float corr) {
public MaxInnerProductScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float corr) {
super(in, values, query, corr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

import java.io.IOException;
Expand Down Expand Up @@ -431,14 +431,13 @@ public Optional<Throwable> call() {
}
}

RandomAccessQuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
QuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
var sq = new ScalarQuantizer(0.1f, 0.9f, (byte) 7);
var slice = in.slice("values", 0, in.length());
return new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(dims, size, sq, false, sim, null, slice);
}

RandomVectorScorerSupplier luceneScoreSupplier(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim)
throws IOException {
RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) throws IOException {
return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexCommit;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
Expand Down Expand Up @@ -544,48 +545,50 @@ void analyzeKnnVectors(SegmentReader reader, IndexDiskUsageStats stats) throws I
if (field.getVectorDimension() > 0) {
switch (field.getVectorEncoding()) {
case BYTE -> {
iterateDocValues(reader.maxDoc(), () -> vectorReader.getByteVectorValues(field.name), vectors -> {
iterateDocValues(reader.maxDoc(), () -> vectorReader.getByteVectorValues(field.name).iterator(), vectors -> {
cancellationChecker.logEvent();
vectors.vectorValue();
vectors.index();
});

// do a couple of randomized searches to figure out min and max offsets of index file
ByteVectorValues vectorValues = vectorReader.getByteVectorValues(field.name);
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
final KnnCollector collector = new TopKnnCollector(
Math.max(1, Math.min(100, vectorValues.size() - 1)),
Integer.MAX_VALUE
);
int numDocsToVisit = reader.maxDoc() < 10 ? reader.maxDoc() : 10 * (int) Math.log10(reader.maxDoc());
int skipFactor = Math.max(reader.maxDoc() / numDocsToVisit, 1);
for (int i = 0; i < reader.maxDoc(); i += skipFactor) {
if ((i = vectorValues.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) {
if ((i = iterator.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
cancellationChecker.checkForCancellation();
vectorReader.search(field.name, vectorValues.vectorValue(), collector, null);
vectorReader.search(field.name, vectorValues.vectorValue(iterator.index()), collector, null);
}
stats.addKnnVectors(field.name, directory.getBytesRead());
}
case FLOAT32 -> {
iterateDocValues(reader.maxDoc(), () -> vectorReader.getFloatVectorValues(field.name), vectors -> {
iterateDocValues(reader.maxDoc(), () -> vectorReader.getFloatVectorValues(field.name).iterator(), vectors -> {
cancellationChecker.logEvent();
vectors.vectorValue();
vectors.index();
});

// do a couple of randomized searches to figure out min and max offsets of index file
FloatVectorValues vectorValues = vectorReader.getFloatVectorValues(field.name);
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
final KnnCollector collector = new TopKnnCollector(
Math.max(1, Math.min(100, vectorValues.size() - 1)),
Integer.MAX_VALUE
);
int numDocsToVisit = reader.maxDoc() < 10 ? reader.maxDoc() : 10 * (int) Math.log10(reader.maxDoc());
int skipFactor = Math.max(reader.maxDoc() / numDocsToVisit, 1);
for (int i = 0; i < reader.maxDoc(); i += skipFactor) {
if ((i = vectorValues.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) {
if ((i = iterator.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
cancellationChecker.checkForCancellation();
vectorReader.search(field.name, vectorValues.vectorValue(), collector, null);
vectorReader.search(field.name, vectorValues.vectorValue(iterator.index()), collector, null);
}
stats.addKnnVectors(field.name, directory.getBytesRead());
}
Expand Down
Loading

0 comments on commit 63e524d

Please sign in to comment.