Skip to content

Commit

Permalink
ESQL: More tests for filtered aggs
Browse files Browse the repository at this point in the history
This adds a test to *every* agg for when it's entirely filtered away and
another when filtering is enabled but unused. I'll follow up with
another test later for partial filtering.
  • Loading branch information
nik9000 committed Sep 30, 2024
1 parent 31bea56 commit a60e961
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ private void addRawInput(IntVector groups) {
*/
private void addRawInput(IntBlock groups) {
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
// TODO remove the check one we don't emit null anymore
if (groups.isNull(groupPosition)) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.EvalOperator;

public record ConstantBooleanExpressionEvaluator(BlockFactory factory, boolean value) implements EvalOperator.ExpressionEvaluator {
public static EvalOperator.ExpressionEvaluator.Factory factory(boolean value) {
return ctx -> new ConstantBooleanExpressionEvaluator(ctx.blockFactory(), value);
}

@Override
public Block eval(Page page) {
return factory.newConstantBooleanVector(value, page.getPositionCount()).asBlock();
}

@Override
public void close() {}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.compute.aggregation;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.ConstantBooleanExpressionEvaluator;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockTestUtils;
Expand All @@ -34,6 +35,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
Expand All @@ -58,8 +60,17 @@ protected final int aggregatorIntermediateBlockCount() {

@Override
protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) {
return simpleWithMode(mode, Function.identity());
}

private Operator.OperatorFactory simpleWithMode(
AggregatorMode mode,
Function<AggregatorFunctionSupplier, AggregatorFunctionSupplier> wrap
) {
List<Integer> channels = mode.isInputPartial() ? range(0, aggregatorIntermediateBlockCount()).boxed().toList() : List.of(0);
return new AggregationOperator.AggregationOperatorFactory(List.of(aggregatorFunction(channels).aggregatorFactory(mode)), mode);
AggregatorFunctionSupplier supplier = aggregatorFunction(channels);
Aggregator.Factory factory = wrap.apply(supplier).aggregatorFactory(mode);
return new AggregationOperator.AggregationOperatorFactory(List.of(factory), mode);
}

@Override
Expand Down Expand Up @@ -141,6 +152,7 @@ public final void testEmptyInput() {
List<Page> results = drive(simple().get(driverContext), List.<Page>of().iterator(), driverContext);

assertThat(results, hasSize(1));
assertOutputFromEmpty(results.get(0).getBlock(0));
}

public final void testEmptyInputInitialFinal() {
Expand All @@ -166,6 +178,31 @@ public final void testEmptyInputInitialIntermediateFinal() {
assertOutputFromEmpty(results.get(0).getBlock(0));
}

public void testAllFiltered() {
Operator.OperatorFactory factory = simpleWithMode(
AggregatorMode.SINGLE,
agg -> new FilteredAggregatorFunctionSupplier(agg, ConstantBooleanExpressionEvaluator.factory(false))
);
DriverContext driverContext = driverContext();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(driverContext.blockFactory(), 10));
List<Page> results = drive(factory.get(driverContext), input.iterator(), driverContext);
assertThat(results, hasSize(1));
assertOutputFromEmpty(results.get(0).getBlock(0));
}

public final void testNoneFiltered() {
Operator.OperatorFactory factory = simpleWithMode(
AggregatorMode.SINGLE,
agg -> new FilteredAggregatorFunctionSupplier(agg, ConstantBooleanExpressionEvaluator.factory(true))
);
DriverContext driverContext = driverContext();
List<Page> input = CannedSourceOperator.collectPages(simpleInput(driverContext.blockFactory(), 10));
List<Page> origInput = BlockTestUtils.deepCopyOf(input, TestBlockFactory.getNonBreakingInstance());
List<Page> results = drive(factory.get(driverContext), input.iterator(), driverContext);
assertThat(results, hasSize(1));
assertSimpleOutput(origInput, results);
}

// Returns an intermediate state that is equivalent to what the local execution planner will emit
// if it determines that certain shards have no relevant data.
List<Page> nullIntermediateState(BlockFactory blockFactory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.LongBooleanTupleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
Expand Down Expand Up @@ -53,4 +55,13 @@ protected void assertOutputFromNullOnly(Block b, int position) {
assertThat(b.getValueCount(position), equalTo(1));
assertThat(((LongBlock) b).getLong(b.getFirstValueIndex(position)), equalTo(0L));
}

@Override
protected void assertOutputFromAllFiltered(Block b) {
assertThat(b.elementType(), equalTo(ElementType.LONG));
LongVector v = (LongVector) b.asVector();
for (int p = 0; p < v.getPositionCount(); p++) {
assertThat(v.getLong(p), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.LongBytesRefTupleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
Expand Down Expand Up @@ -58,4 +60,13 @@ protected void assertOutputFromNullOnly(Block b, int position) {
assertThat(b.getValueCount(position), equalTo(1));
assertThat(((LongBlock) b).getLong(b.getFirstValueIndex(position)), equalTo(0L));
}

@Override
protected void assertOutputFromAllFiltered(Block b) {
assertThat(b.elementType(), equalTo(ElementType.LONG));
LongVector v = (LongVector) b.asVector();
for (int p = 0; p < v.getPositionCount(); p++) {
assertThat(v.getLong(p), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.LongDoubleTupleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
Expand Down Expand Up @@ -57,4 +59,13 @@ protected void assertOutputFromNullOnly(Block b, int position) {
assertThat(b.getValueCount(position), equalTo(1));
assertThat(((LongBlock) b).getLong(b.getFirstValueIndex(position)), equalTo(0L));
}

@Override
protected void assertOutputFromAllFiltered(Block b) {
assertThat(b.elementType(), equalTo(ElementType.LONG));
LongVector v = (LongVector) b.asVector();
for (int p = 0; p < v.getPositionCount(); p++) {
assertThat(v.getLong(p), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.LongFloatTupleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
Expand Down Expand Up @@ -57,4 +59,13 @@ protected void assertOutputFromNullOnly(Block b, int position) {
assertThat(b.getValueCount(position), equalTo(1));
assertThat(((LongBlock) b).getLong(b.getFirstValueIndex(position)), equalTo(0L));
}

@Override
protected void assertOutputFromAllFiltered(Block b) {
assertThat(b.elementType(), equalTo(ElementType.LONG));
LongVector v = (LongVector) b.asVector();
for (int p = 0; p < v.getPositionCount(); p++) {
assertThat(v.getLong(p), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.LongIntBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
Expand Down Expand Up @@ -57,4 +59,13 @@ protected void assertOutputFromNullOnly(Block b, int position) {
assertThat(b.getValueCount(position), equalTo(1));
assertThat(((LongBlock) b).getLong(b.getFirstValueIndex(position)), equalTo(0L));
}

@Override
protected void assertOutputFromAllFiltered(Block b) {
assertThat(b.elementType(), equalTo(ElementType.LONG));
LongVector v = (LongVector) b.asVector();
for (int p = 0; p < v.getPositionCount(); p++) {
assertThat(v.getLong(p), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.compute.operator.TupleBlockSourceOperator;
Expand Down Expand Up @@ -56,4 +58,13 @@ protected void assertOutputFromNullOnly(Block b, int position) {
assertThat(b.getValueCount(position), equalTo(1));
assertThat(((LongBlock) b).getLong(b.getFirstValueIndex(position)), equalTo(0L));
}

@Override
protected void assertOutputFromAllFiltered(Block b) {
assertThat(b.elementType(), equalTo(ElementType.LONG));
LongVector v = (LongVector) b.asVector();
for (int p = 0; p < v.getPositionCount(); p++) {
assertThat(v.getLong(p), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.LongDoubleTupleBlockSourceOperator;
import org.elasticsearch.compute.operator.SourceOperator;
Expand Down Expand Up @@ -58,4 +60,13 @@ protected void assertOutputFromNullOnly(Block b, int position) {
assertThat(b.getValueCount(position), equalTo(1));
assertThat(((LongBlock) b).getLong(b.getFirstValueIndex(position)), equalTo(0L));
}

@Override
protected void assertOutputFromAllFiltered(Block b) {
assertThat(b.elementType(), equalTo(ElementType.LONG));
LongVector v = (LongVector) b.asVector();
for (int p = 0; p < v.getPositionCount(); p++) {
assertThat(v.getLong(p), equalTo(0L));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,9 @@ public void checkUnclosed() {
}
assertThat(unclosed, empty());
}

@Override
public void testAllFiltered() {
assumeFalse("can't double filter. tests already filter.", true);
}
}
Loading

0 comments on commit a60e961

Please sign in to comment.