diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java index 569e8909e1e12..b294fe97c7e7c 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java @@ -19,7 +19,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.core.IOUtils; @@ -217,19 +217,17 @@ public float squareDistanceScalar() { return 1 / (1f + adjustedDistance); } - RandomAccessQuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + QuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { var sq = new ScalarQuantizer(0.1f, 0.9f, (byte) 7); var slice = in.slice("values", 0, in.length()); return new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(dims, size, sq, false, sim, null, slice); } - RandomVectorScorerSupplier luceneScoreSupplier(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim) - throws IOException { + RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) throws IOException { return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); } - RandomVectorScorer luceneScorer(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVec) - throws IOException { + RandomVectorScorer luceneScorer(QuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVec) throws IOException { return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorer(sim, values, queryVec); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java index e2aea6b3ebd9f..4ed60b2f5e8b2 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java @@ -13,7 +13,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import java.util.Optional; @@ -39,7 +39,7 @@ static Optional instance() { Optional getInt7SQVectorScorerSupplier( VectorSimilarityType similarityType, IndexInput input, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float scoreCorrectionConstant ); @@ -52,9 +52,5 @@ Optional getInt7SQVectorScorerSupplier( * @param queryVector the query vector * @return an optional containing the vector scorer, or empty */ - Optional getInt7SQVectorScorer( - VectorSimilarityFunction sim, - RandomAccessQuantizedByteVectorValues values, - float[] queryVector - ); + Optional getInt7SQVectorScorer(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index a22d787980252..6248902c32e7a 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -13,7 +13,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import java.util.Optional; @@ -25,7 +25,7 @@ final class VectorScorerFactoryImpl implements VectorScorerFactory { public Optional getInt7SQVectorScorerSupplier( VectorSimilarityType similarityType, IndexInput input, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float scoreCorrectionConstant ) { throw new UnsupportedOperationException("should not reach here"); @@ -34,7 +34,7 @@ public Optional getInt7SQVectorScorerSupplier( @Override public Optional getInt7SQVectorScorer( VectorSimilarityFunction sim, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float[] queryVector ) { throw new UnsupportedOperationException("should not reach here"); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index a65fe582087d9..a863d9e3448ca 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -15,7 +15,7 @@ import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.simdvec.internal.Int7SQVectorScorer; import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.DotProductSupplier; @@ -38,7 +38,7 @@ private VectorScorerFactoryImpl() {} public Optional getInt7SQVectorScorerSupplier( VectorSimilarityType similarityType, IndexInput input, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float scoreCorrectionConstant ) { input = FilterIndexInput.unwrapOnlyTest(input); @@ -57,7 +57,7 @@ public Optional getInt7SQVectorScorerSupplier( @Override public Optional getInt7SQVectorScorer( VectorSimilarityFunction sim, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float[] queryVector ) { return Int7SQVectorScorer.create(sim, values, queryVector); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java index 0b41436ce2242..e02df124ad0f0 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java @@ -11,18 +11,14 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import java.util.Optional; public final class Int7SQVectorScorer { // Unconditionally returns an empty optional on <= JDK 21, since the scorer is only supported on JDK 22+ - public static Optional create( - VectorSimilarityFunction sim, - RandomAccessQuantizedByteVectorValues values, - float[] queryVector - ) { + public static Optional create(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector) { return Optional.empty(); } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java index f6d874cd3e728..198e10406056e 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java @@ -12,7 +12,7 @@ import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import java.io.IOException; @@ -31,12 +31,12 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS final int maxOrd; final float scoreCorrectionConstant; final MemorySegmentAccessInput input; - final RandomAccessQuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds + final QuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds final ScalarQuantizedVectorSimilarity fallbackScorer; protected Int7SQVectorScorerSupplier( MemorySegmentAccessInput input, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float scoreCorrectionConstant, ScalarQuantizedVectorSimilarity fallbackScorer ) { @@ -104,11 +104,7 @@ public float score(int node) throws IOException { public static final class EuclideanSupplier extends Int7SQVectorScorerSupplier { - public EuclideanSupplier( - MemorySegmentAccessInput input, - RandomAccessQuantizedByteVectorValues values, - float scoreCorrectionConstant - ) { + public EuclideanSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) { super(input, values, scoreCorrectionConstant, fromVectorSimilarity(EUCLIDEAN, scoreCorrectionConstant, BITS)); } @@ -127,11 +123,7 @@ public EuclideanSupplier copy() { public static final class DotProductSupplier extends Int7SQVectorScorerSupplier { - public DotProductSupplier( - MemorySegmentAccessInput input, - RandomAccessQuantizedByteVectorValues values, - float scoreCorrectionConstant - ) { + public DotProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) { super(input, values, scoreCorrectionConstant, fromVectorSimilarity(DOT_PRODUCT, scoreCorrectionConstant, BITS)); } @@ -151,11 +143,7 @@ public DotProductSupplier copy() { public static final class MaxInnerProductSupplier extends Int7SQVectorScorerSupplier { - public MaxInnerProductSupplier( - MemorySegmentAccessInput input, - RandomAccessQuantizedByteVectorValues values, - float scoreCorrectionConstant - ) { + public MaxInnerProductSupplier(MemorySegmentAccessInput input, QuantizedByteVectorValues values, float scoreCorrectionConstant) { super(input, values, scoreCorrectionConstant, fromVectorSimilarity(MAXIMUM_INNER_PRODUCT, scoreCorrectionConstant, BITS)); } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java index c9659ea1af9a8..3d0e1e71a3744 100644 --- a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java @@ -15,7 +15,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; import java.io.IOException; @@ -35,11 +35,7 @@ public abstract sealed class Int7SQVectorScorer extends RandomVectorScorer.Abstr byte[] scratch; /** Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is returned. */ - public static Optional create( - VectorSimilarityFunction sim, - RandomAccessQuantizedByteVectorValues values, - float[] queryVector - ) { + public static Optional create(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector) { checkDimensions(queryVector.length, values.dimension()); var input = values.getSlice(); if (input == null) { @@ -63,12 +59,7 @@ public static Optional create( }; } - Int7SQVectorScorer( - MemorySegmentAccessInput input, - RandomAccessQuantizedByteVectorValues values, - byte[] queryVector, - float queryCorrection - ) { + Int7SQVectorScorer(MemorySegmentAccessInput input, QuantizedByteVectorValues values, byte[] queryVector, float queryCorrection) { super(values); this.input = input; assert queryVector.length == values.getVectorByteLength(); @@ -105,7 +96,7 @@ final void checkOrdinal(int ord) { } public static final class DotProductScorer extends Int7SQVectorScorer { - public DotProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) { + public DotProductScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float correction) { super(in, values, query, correction); } @@ -122,7 +113,7 @@ public float score(int node) throws IOException { } public static final class EuclideanScorer extends Int7SQVectorScorer { - public EuclideanScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) { + public EuclideanScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float correction) { super(in, values, query, correction); } @@ -136,7 +127,7 @@ public float score(int node) throws IOException { } public static final class MaxInnerProductScorer extends Int7SQVectorScorer { - public MaxInnerProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float corr) { + public MaxInnerProductScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float corr) { super(in, values, query, corr); } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java index db57dc936e794..0f967127f6f2c 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java @@ -21,7 +21,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; import java.io.IOException; @@ -431,14 +431,13 @@ public Optional call() { } } - RandomAccessQuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + QuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { var sq = new ScalarQuantizer(0.1f, 0.9f, (byte) 7); var slice = in.slice("values", 0, in.length()); return new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(dims, size, sq, false, sim, null, slice); } - RandomVectorScorerSupplier luceneScoreSupplier(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim) - throws IOException { + RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) throws IOException { return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java b/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java index 5acbbb5536560..e668624440351 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/diskusage/IndexDiskUsageAnalyzer.java @@ -32,6 +32,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexCommit; import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.PointValues; @@ -544,13 +545,14 @@ void analyzeKnnVectors(SegmentReader reader, IndexDiskUsageStats stats) throws I if (field.getVectorDimension() > 0) { switch (field.getVectorEncoding()) { case BYTE -> { - iterateDocValues(reader.maxDoc(), () -> vectorReader.getByteVectorValues(field.name), vectors -> { + iterateDocValues(reader.maxDoc(), () -> vectorReader.getByteVectorValues(field.name).iterator(), vectors -> { cancellationChecker.logEvent(); - vectors.vectorValue(); + vectors.index(); }); // do a couple of randomized searches to figure out min and max offsets of index file ByteVectorValues vectorValues = vectorReader.getByteVectorValues(field.name); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); final KnnCollector collector = new TopKnnCollector( Math.max(1, Math.min(100, vectorValues.size() - 1)), Integer.MAX_VALUE @@ -558,22 +560,23 @@ void analyzeKnnVectors(SegmentReader reader, IndexDiskUsageStats stats) throws I int numDocsToVisit = reader.maxDoc() < 10 ? reader.maxDoc() : 10 * (int) Math.log10(reader.maxDoc()); int skipFactor = Math.max(reader.maxDoc() / numDocsToVisit, 1); for (int i = 0; i < reader.maxDoc(); i += skipFactor) { - if ((i = vectorValues.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) { + if ((i = iterator.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) { break; } cancellationChecker.checkForCancellation(); - vectorReader.search(field.name, vectorValues.vectorValue(), collector, null); + vectorReader.search(field.name, vectorValues.vectorValue(iterator.index()), collector, null); } stats.addKnnVectors(field.name, directory.getBytesRead()); } case FLOAT32 -> { - iterateDocValues(reader.maxDoc(), () -> vectorReader.getFloatVectorValues(field.name), vectors -> { + iterateDocValues(reader.maxDoc(), () -> vectorReader.getFloatVectorValues(field.name).iterator(), vectors -> { cancellationChecker.logEvent(); - vectors.vectorValue(); + vectors.index(); }); // do a couple of randomized searches to figure out min and max offsets of index file FloatVectorValues vectorValues = vectorReader.getFloatVectorValues(field.name); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); final KnnCollector collector = new TopKnnCollector( Math.max(1, Math.min(100, vectorValues.size() - 1)), Integer.MAX_VALUE @@ -581,11 +584,11 @@ void analyzeKnnVectors(SegmentReader reader, IndexDiskUsageStats stats) throws I int numDocsToVisit = reader.maxDoc() < 10 ? reader.maxDoc() : 10 * (int) Math.log10(reader.maxDoc()); int skipFactor = Math.max(reader.maxDoc() / numDocsToVisit, 1); for (int i = 0; i < reader.maxDoc(); i += skipFactor) { - if ((i = vectorValues.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) { + if ((i = iterator.advance(i)) == DocIdSetIterator.NO_MORE_DOCS) { break; } cancellationChecker.checkForCancellation(); - vectorReader.search(field.name, vectorValues.vectorValue(), collector, null); + vectorReader.search(field.name, vectorValues.vectorValue(iterator.index()), collector, null); } stats.addKnnVectors(field.name, directory.getBytesRead()); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java index 4313aa40cf13e..e78fc22f3215f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java @@ -22,18 +22,17 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.QuantizedVectorsReader; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; import org.elasticsearch.simdvec.VectorScorerFactory; import org.elasticsearch.simdvec.VectorSimilarityType; @@ -243,9 +242,9 @@ public String toString() { } @Override - public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, RandomAccessVectorValues values) + public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, KnnVectorValues values) throws IOException { - if (values instanceof RandomAccessQuantizedByteVectorValues qValues && values.getSlice() != null) { + if (values instanceof QuantizedByteVectorValues qValues && qValues.getSlice() != null) { // TODO: optimize int4 quantization if (qValues.getScalarQuantizer().getBits() != 7) { return delegate.getRandomVectorScorerSupplier(sim, values); @@ -253,7 +252,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarity if (factory != null) { var scorer = factory.getInt7SQVectorScorerSupplier( VectorSimilarityType.of(sim), - values.getSlice(), + qValues.getSlice(), qValues, qValues.getScalarQuantizer().getConstantMultiplier() ); @@ -266,9 +265,9 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarity } @Override - public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, RandomAccessVectorValues values, float[] query) + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, float[] query) throws IOException { - if (values instanceof RandomAccessQuantizedByteVectorValues qValues && values.getSlice() != null) { + if (values instanceof QuantizedByteVectorValues qValues && qValues.getSlice() != null) { // TODO: optimize int4 quantization if (qValues.getScalarQuantizer().getBits() != 7) { return delegate.getRandomVectorScorer(sim, values, query); @@ -284,7 +283,7 @@ public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, Ra } @Override - public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, RandomAccessVectorValues values, byte[] query) + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, byte[] query) throws IOException { return delegate.getRandomVectorScorer(sim, values, query); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index f1ae4e3fdeded..29e179dfc7c5d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -14,13 +14,14 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.elasticsearch.script.field.vectors.ESVectorUtil; import java.io.IOException; @@ -61,14 +62,14 @@ public String toString() { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues + KnnVectorValues vectorValues ) throws IOException { - assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes; + assert vectorValues instanceof ByteVectorValues; assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN; - if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) { - assert randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues == false; + if (vectorValues instanceof ByteVectorValues byteVectorValues) { + assert byteVectorValues instanceof QuantizedByteVectorValues == false; return switch (vectorSimilarityFunction) { - case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingScorerSupplier(randomAccessVectorValuesBytes); + case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingScorerSupplier(byteVectorValues); }; } throw new IllegalArgumentException("Unsupported vector type or similarity function"); @@ -77,18 +78,15 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues, - byte[] bytes - ) { - assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes; + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + assert vectorValues instanceof ByteVectorValues; assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN; - if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) { - checkDimensions(bytes.length, randomAccessVectorValuesBytes.dimension()); + if (vectorValues instanceof ByteVectorValues byteVectorValues) { + checkDimensions(target.length, byteVectorValues.dimension()); return switch (vectorSimilarityFunction) { - case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingVectorScorer( - randomAccessVectorValuesBytes, - bytes - ); + case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingVectorScorer(byteVectorValues, target); }; } throw new IllegalArgumentException("Unsupported vector type or similarity function"); @@ -96,10 +94,10 @@ public RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessVectorValues randomAccessVectorValues, - float[] floats - ) { + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { throw new IllegalArgumentException("Unsupported vector type"); } } @@ -110,9 +108,9 @@ static float hammingScore(byte[] a, byte[] b) { static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { private final byte[] query; - private final RandomAccessVectorValues.Bytes byteValues; + private final ByteVectorValues byteValues; - HammingVectorScorer(RandomAccessVectorValues.Bytes byteValues, byte[] query) { + HammingVectorScorer(ByteVectorValues byteValues, byte[] query) { super(byteValues); this.query = query; this.byteValues = byteValues; @@ -125,9 +123,9 @@ public float score(int i) throws IOException { } static class HammingScorerSupplier implements RandomVectorScorerSupplier { - private final RandomAccessVectorValues.Bytes byteValues, byteValues1, byteValues2; + private final ByteVectorValues byteValues, byteValues1, byteValues2; - HammingScorerSupplier(RandomAccessVectorValues.Bytes byteValues) throws IOException { + HammingScorerSupplier(ByteVectorValues byteValues) throws IOException { this.byteValues = byteValues; this.byteValues1 = byteValues.copy(); this.byteValues2 = byteValues.copy(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java index e8da3b72ae7c7..04069333deb13 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java @@ -45,24 +45,13 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - // Lazy load vectors as we may iterate but not actually require the vector - return vectorValue(in.docID()); + public DocIndexIterator iterator() { + return in.iterator(); } @Override - public int docID() { - return in.docID(); - } - - @Override - public int nextDoc() throws IOException { - return in.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - return in.advance(target); + public FloatVectorValues copy() throws IOException { + return in.copy(); } @Override @@ -74,22 +63,24 @@ public float magnitude() { return magnitude; } - private float[] vectorValue(int docId) throws IOException { + @Override + public float[] vectorValue(int ord) throws IOException { + int docId = ordToDoc(ord); if (docId != this.docId) { this.docId = docId; hasMagnitude = decodedMagnitude(docId); // We should only copy and transform if we have a stored a non-unit length magnitude if (hasMagnitude) { - System.arraycopy(in.vectorValue(), 0, vector, 0, dimension()); + System.arraycopy(in.vectorValue(ord), 0, vector, 0, dimension()); for (int i = 0; i < vector.length; i++) { vector[i] *= magnitude; } return vector; } else { - return in.vectorValue(); + return in.vectorValue(ord); } } else { - return hasMagnitude ? vector : in.vectorValue(); + return hasMagnitude ? vector : in.vectorValue(ord); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 4adfe619ca4e1..a48af90d539e6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SegmentReadState; @@ -2167,6 +2168,7 @@ private class IndexedSyntheticFieldLoader extends SourceLoader.DocValuesBasedSyn private ByteVectorValues byteVectorValues; private boolean hasValue; private boolean hasMagnitude; + private int ord; private final IndexVersion indexCreatedVersion; private final VectorSimilarity vectorSimilarity; @@ -2184,16 +2186,20 @@ public DocValuesLoader docValuesLoader(LeafReader leafReader, int[] docIdsInLeaf if (indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(vectorSimilarity)) { magnitudeReader = leafReader.getNumericDocValues(fullPath() + COSINE_MAGNITUDE_FIELD_SUFFIX); } + KnnVectorValues.DocIndexIterator iterator = values.iterator(); return docId -> { - hasValue = docId == values.advance(docId); + hasValue = docId == iterator.advance(docId); hasMagnitude = hasValue && magnitudeReader != null && magnitudeReader.advanceExact(docId); + ord = iterator.index(); return hasValue; }; } byteVectorValues = leafReader.getByteVectorValues(fullPath()); if (byteVectorValues != null) { + KnnVectorValues.DocIndexIterator iterator = byteVectorValues.iterator(); return docId -> { - hasValue = docId == byteVectorValues.advance(docId); + hasValue = docId == iterator.advance(docId); + ord = iterator.index(); return hasValue; }; } @@ -2216,7 +2222,7 @@ public void write(XContentBuilder b) throws IOException { } b.startArray(leafName()); if (values != null) { - for (float v : values.vectorValue()) { + for (float v : values.vectorValue(ord)) { if (hasMagnitude) { b.value(v * magnitude); } else { @@ -2224,7 +2230,7 @@ public void write(XContentBuilder b) throws IOException { } } } else if (byteVectorValues != null) { - byte[] vectorValue = byteVectorValues.vectorValue(); + byte[] vectorValue = byteVectorValues.vectorValue(ord); for (byte value : vectorValue) { b.value(value); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java index fd7c5227e22ac..be1b972dcd41a 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java @@ -10,6 +10,7 @@ package org.elasticsearch.script.field.vectors; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues; @@ -19,7 +20,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField { - protected ByteVectorValues input; // null if no vectors + protected final ByteVectorValues input; // null if no vectors + protected final KnnVectorValues.DocIndexIterator iterator; // null if no vectors protected byte[] vector; protected final int dims; @@ -31,6 +33,7 @@ protected ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, Str super(name, elementType); this.dims = dims; this.input = input; + this.iterator = input == null ? null : input.iterator(); } @Override @@ -38,15 +41,15 @@ public void setNextDocId(int docId) throws IOException { if (input == null) { return; } - int currentDoc = input.docID(); + int currentDoc = iterator.docID(); if (currentDoc == NO_MORE_DOCS || docId < currentDoc) { vector = null; } else if (docId == currentDoc) { - vector = input.vectorValue(); + vector = input.vectorValue(iterator.index()); } else { - currentDoc = input.advance(docId); + currentDoc = iterator.advance(docId); if (currentDoc == docId) { - vector = input.vectorValue(); + vector = input.vectorValue(iterator.index()); } else { vector = null; } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/KnnDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/KnnDenseVectorDocValuesField.java index c7678b03dd8c5..3e38092200511 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/KnnDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/KnnDenseVectorDocValuesField.java @@ -10,6 +10,7 @@ package org.elasticsearch.script.field.vectors; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenormalizedCosineFloatVectorValues; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; @@ -20,7 +21,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; public class KnnDenseVectorDocValuesField extends DenseVectorDocValuesField { - protected FloatVectorValues input; // null if no vectors + protected final FloatVectorValues input; // null if no vectors + protected final KnnVectorValues.DocIndexIterator iterator; protected float[] vector; protected final int dims; @@ -28,6 +30,7 @@ public KnnDenseVectorDocValuesField(@Nullable FloatVectorValues input, String na super(name, ElementType.FLOAT); this.dims = dims; this.input = input; + this.iterator = input == null ? null : input.iterator(); } @Override @@ -35,15 +38,15 @@ public void setNextDocId(int docId) throws IOException { if (input == null) { return; } - int currentDoc = input.docID(); + int currentDoc = iterator.docID(); if (currentDoc == NO_MORE_DOCS || docId < currentDoc) { vector = null; } else if (docId == currentDoc) { - vector = input.vectorValue(); + vector = input.vectorValue(iterator.index()); } else { - currentDoc = input.advance(docId); + currentDoc = iterator.advance(docId); if (currentDoc == docId) { - vector = input.vectorValue(); + vector = input.vectorValue(iterator.index()); } else { vector = null; } diff --git a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java index ad0979cb3a481..64b54d3623f04 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java @@ -16,6 +16,7 @@ import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.PointValues; import org.apache.lucene.index.QueryTimeout; @@ -459,7 +460,6 @@ public void grow(int count) { } private static class ExitableByteVectorValues extends ByteVectorValues { - private int calls; private final QueryCancellation queryCancellation; private final ByteVectorValues in; @@ -479,8 +479,13 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { - return in.vectorValue(); + public byte[] vectorValue(int ord) throws IOException { + return in.vectorValue(ord); + } + + @Override + public int ordToDoc(int ord) { + return in.ordToDoc(ord); } @Override @@ -505,33 +510,17 @@ public DocIdSetIterator iterator() { } @Override - public int docID() { - return in.docID(); + public DocIndexIterator iterator() { + return createExitableIterator(in.iterator(), queryCancellation); } @Override - public int nextDoc() throws IOException { - final int nextDoc = in.nextDoc(); - checkAndThrowWithSampling(); - return nextDoc; - } - - @Override - public int advance(int target) throws IOException { - final int advance = in.advance(target); - checkAndThrowWithSampling(); - return advance; - } - - private void checkAndThrowWithSampling() { - if ((calls++ & ExitableIntersectVisitor.MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) { - this.queryCancellation.checkCancelled(); - } + public ByteVectorValues copy() throws IOException { + return in.copy(); } } private static class ExitableFloatVectorValues extends FilterFloatVectorValues { - private int calls; private final QueryCancellation queryCancellation; ExitableFloatVectorValues(FloatVectorValues vectorValues, QueryCancellation queryCancellation) { @@ -541,17 +530,13 @@ private static class ExitableFloatVectorValues extends FilterFloatVectorValues { } @Override - public int advance(int target) throws IOException { - final int advance = super.advance(target); - checkAndThrowWithSampling(); - return advance; + public float[] vectorValue(int ord) throws IOException { + return in.vectorValue(ord); } @Override - public int nextDoc() throws IOException { - final int nextDoc = super.nextDoc(); - checkAndThrowWithSampling(); - return nextDoc; + public int ordToDoc(int ord) { + return in.ordToDoc(ord); } @Override @@ -575,13 +560,61 @@ public DocIdSetIterator iterator() { }; } - private void checkAndThrowWithSampling() { - if ((calls++ & ExitableIntersectVisitor.MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) { - this.queryCancellation.checkCancelled(); - } + @Override + public DocIndexIterator iterator() { + return createExitableIterator(in.iterator(), queryCancellation); + } + + @Override + public FloatVectorValues copy() throws IOException { + return in.copy(); } } + private static KnnVectorValues.DocIndexIterator createExitableIterator( + KnnVectorValues.DocIndexIterator delegate, + QueryCancellation queryCancellation + ) { + return new KnnVectorValues.DocIndexIterator() { + private int calls; + + @Override + public int index() { + return delegate.index(); + } + + @Override + public int docID() { + return delegate.docID(); + } + + @Override + public long cost() { + return delegate.cost(); + } + + @Override + public int nextDoc() throws IOException { + int nextDoc = delegate.nextDoc(); + checkAndThrowWithSampling(); + return nextDoc; + } + + @Override + public int advance(int target) throws IOException { + final int advance = delegate.advance(target); + checkAndThrowWithSampling(); + return advance; + } + + private void checkAndThrowWithSampling() { + if ((calls++ & ExitableIntersectVisitor.MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) { + queryCancellation.checkCancelled(); + } + } + }; + } + private static class ExitableDocSetIterator extends DocIdSetIterator { private int calls; private final DocIdSetIterator in; @@ -636,18 +669,18 @@ protected FilterFloatVectorValues(FloatVectorValues in) { } @Override - public int docID() { - return in.docID(); + public DocIndexIterator iterator() { + return in.iterator(); } @Override - public int nextDoc() throws IOException { - return in.nextDoc(); + public float[] vectorValue(int ord) throws IOException { + return in.vectorValue(ord); } @Override - public int advance(int target) throws IOException { - return in.advance(target); + public FloatVectorValues copy() throws IOException { + return in.copy(); } @Override @@ -660,9 +693,5 @@ public int size() { return in.size(); } - @Override - public float[] vectorValue() throws IOException { - return in.vectorValue(); - } } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseKnnBitVectorsFormatTestCase.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseKnnBitVectorsFormatTestCase.java index 8f0a306e1eb3b..86b60d9984de5 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseKnnBitVectorsFormatTestCase.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseKnnBitVectorsFormatTestCase.java @@ -19,6 +19,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; @@ -110,8 +111,9 @@ public void testRandom() throws Exception { totalSize += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); int docId; - while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { - byte[] v = vectorValues.vectorValue(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while ((docId = iterator.nextDoc()) != NO_MORE_DOCS) { + byte[] v = vectorValues.vectorValue(iterator.index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java index cee60efb57327..f89b481a13fd8 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java @@ -19,6 +19,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.VectorSimilarityFunction; @@ -68,9 +69,10 @@ public void testAddIndexesDirectory0FS() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.nextDoc()); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); + assertEquals(0, vectorValues.vectorValue(iterator.index())[0], 0); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } } } @@ -110,12 +112,13 @@ private void testAddIndexesDirectory01FS(VectorSimilarityFunction similarityFunc try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); // The merge order is randomized, we might get 1 first, or 2 - float value = vectorValues.vectorValue()[0]; + float value = vectorValues.vectorValue(iterator.index())[0]; assertTrue(value == 1 || value == 2); - assertEquals(1, vectorValues.nextDoc()); - value += vectorValues.vectorValue()[0]; + assertEquals(1, iterator.nextDoc()); + value += vectorValues.vectorValue(iterator.index())[0]; assertEquals(3f, value, 0); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValuesTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValuesTests.java index b2ffb779be00b..de4ab0bc5df30 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValuesTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValuesTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.index.mapper.vectors; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.NumericDocValues; import org.elasticsearch.test.ESTestCase; @@ -25,7 +26,7 @@ public void testEmptyVectors() throws IOException { wrap(new float[0][0]), wrapMagnitudes(new float[0]) ); - assertEquals(NO_MORE_DOCS, normalizedCosineFloatVectorValues.nextDoc()); + assertEquals(NO_MORE_DOCS, normalizedCosineFloatVectorValues.iterator().nextDoc()); } public void testRandomVectors() throws IOException { @@ -47,9 +48,10 @@ public void testRandomVectors() throws IOException { wrapMagnitudes(magnitudes) ); + KnnVectorValues.DocIndexIterator iterator = normalizedCosineFloatVectorValues.iterator(); for (int i = 0; i < numVectors; i++) { - assertEquals(i, normalizedCosineFloatVectorValues.advance(i)); - assertArrayEquals(vectors[i], normalizedCosineFloatVectorValues.vectorValue(), (float) 1e-6); + assertEquals(i, iterator.advance(i)); + assertArrayEquals(vectors[i], normalizedCosineFloatVectorValues.vectorValue(iterator.index()), (float) 1e-6); assertEquals(magnitudes[i], normalizedCosineFloatVectorValues.magnitude(), (float) 1e-6); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/KnnDenseVectorScriptDocValuesTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/KnnDenseVectorScriptDocValuesTests.java index c007156c806eb..baade683a90fd 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/KnnDenseVectorScriptDocValuesTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/KnnDenseVectorScriptDocValuesTests.java @@ -208,7 +208,41 @@ public int size() { } @Override - public byte[] vectorValue() { + public DocIndexIterator iterator() { + return new DocIndexIterator() { + @Override + public int index() { + return index; + } + + @Override + public int docID() { + return index; + } + + @Override + public int nextDoc() { + throw new UnsupportedOperationException(); + } + + @Override + public int advance(int target) { + if (target >= size()) { + return NO_MORE_DOCS; + } + return index = target; + } + + @Override + public long cost() { + return 0; + } + }; + } + + @Override + public byte[] vectorValue(int ord) { + assert ord == index; for (int i = 0; i < byteVector.length; i++) { byteVector[i] = (byte) vectors[index][i]; } @@ -216,25 +250,12 @@ public byte[] vectorValue() { } @Override - public int docID() { - return index; - } - - @Override - public int nextDoc() { + public ByteVectorValues copy() { throw new UnsupportedOperationException(); } @Override - public int advance(int target) { - if (target >= size()) { - return NO_MORE_DOCS; - } - return index = target; - } - - @Override - public VectorScorer scorer(byte[] floats) throws IOException { + public VectorScorer scorer(byte[] floats) { throw new UnsupportedOperationException(); } }; @@ -256,30 +277,51 @@ public int size() { } @Override - public float[] vectorValue() { - return vectors[index]; - } - - @Override - public int docID() { - return index; + public DocIndexIterator iterator() { + return new DocIndexIterator() { + @Override + public int index() { + return index; + } + + @Override + public int docID() { + return index; + } + + @Override + public int nextDoc() throws IOException { + return advance(index + 1); + } + + @Override + public int advance(int target) throws IOException { + if (target >= size()) { + return NO_MORE_DOCS; + } + return index = target; + } + + @Override + public long cost() { + return 0; + } + }; } @Override - public int nextDoc() { - return advance(index + 1); + public float[] vectorValue(int ord) { + assert ord == index; + return vectors[index]; } @Override - public int advance(int target) { - if (target >= size()) { - return NO_MORE_DOCS; - } - return index = target; + public FloatVectorValues copy() { + throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(float[] floats) throws IOException { + public VectorScorer scorer(float[] floats) { throw new UnsupportedOperationException(); } }; diff --git a/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java b/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java index fff5dcb4bb80b..f3357a72c9243 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java @@ -15,6 +15,7 @@ import org.apache.lucene.document.StringField; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.PointValues; @@ -205,15 +206,17 @@ public void testExitableDirectoryReaderVectors() throws IOException { cancelled.set(false); // Avoid exception during construction of the wrapper objects FloatVectorValues vectorValues = searcher.getIndexReader().leaves().get(0).reader().getFloatVectorValues(KNN_FIELD_NAME); cancelled.set(true); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); // On the first doc when already canceled, it throws - expectThrows(TaskCancelledException.class, vectorValues::nextDoc); + expectThrows(TaskCancelledException.class, iterator::nextDoc); cancelled.set(false); // Avoid exception during construction of the wrapper objects FloatVectorValues uncancelledVectorValues = searcher.getIndexReader().leaves().get(0).reader().getFloatVectorValues(KNN_FIELD_NAME); + uncancelledVectorValues.iterator(); cancelled.set(true); searcher.removeQueryCancellation(cancellation); // On the first doc when already canceled, it throws, but with the cancellation removed, it should not - uncancelledVectorValues.nextDoc(); + iterator.nextDoc(); } private static class PointValuesIntersectVisitor implements PointValues.IntersectVisitor { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/accesscontrol/FieldSubsetReaderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/accesscontrol/FieldSubsetReaderTests.java index dbabc891cec6e..db250b16eab16 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/accesscontrol/FieldSubsetReaderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/accesscontrol/FieldSubsetReaderTests.java @@ -30,6 +30,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NoMergePolicy; @@ -205,8 +206,9 @@ public void testKnnVectors() throws Exception { FloatVectorValues vectorValues = leafReader.getFloatVectorValues("fieldA"); assertEquals(3, vectorValues.dimension()); assertEquals(1, vectorValues.size()); - assertEquals(0, vectorValues.nextDoc()); - assertNotNull(vectorValues.vectorValue()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); + assertNotNull(vectorValues.vectorValue(iterator.index())); TopDocs topDocs = leafReader.searchNearestVectors("fieldA", new float[] { 1.0f, 1.0f, 1.0f }, 5, null, Integer.MAX_VALUE); assertNotNull(topDocs); @@ -239,8 +241,9 @@ public void testKnnByteVectors() throws Exception { ByteVectorValues vectorValues = leafReader.getByteVectorValues("fieldA"); assertEquals(3, vectorValues.dimension()); assertEquals(1, vectorValues.size()); - assertEquals(0, vectorValues.nextDoc()); - assertNotNull(vectorValues.vectorValue()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); + assertNotNull(vectorValues.vectorValue(iterator.index())); TopDocs topDocs = leafReader.searchNearestVectors("fieldA", new byte[] { 1, 1, 1 }, 5, null, Integer.MAX_VALUE); assertNotNull(topDocs);