Skip to content

Commit

Permalink
ESQL: Add CATEGORIZE() check to avoid having multiple groupings (#116660) (#116821)
Browse files Browse the repository at this point in the history

Added checks to reject unsupported usages of the `CATEGORIZE` grouping function:
- It can't be combined with other groupings
- It can't be nested within other functions
- It can't be used or referenced on the aggregates side
  • Loading branch information
ivancea authored Nov 14, 2024
1 parent 8126bf5 commit 3e64817
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
Expand All @@ -33,6 +34,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
Expand All @@ -56,10 +58,12 @@
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
Expand Down Expand Up @@ -271,6 +275,7 @@ private static void checkAggregate(LogicalPlan p, Set<Failure> failures) {
r -> failures.add(fail(r, "the rate aggregate[{}] can only be used within the metrics command", r.sourceText()))
);
}
checkCategorizeGrouping(agg, failures);
} else {
p.forEachExpression(
GroupingFunction.class,
Expand All @@ -279,6 +284,74 @@ private static void checkAggregate(LogicalPlan p, Set<Failure> failures) {
}
}

/**
 * Validates the usages of the CATEGORIZE grouping function inside an aggregation.
 * <p>
 * Four kinds of usage are reported as failures, in this order:
 * </p>
 * <ol>
 *   <li>a CATEGORIZE appearing anywhere in the groupings when there is more than one grouping;</li>
 *   <li>a CATEGORIZE nested below the top-level (unaliased) expression of a grouping;</li>
 *   <li>a CATEGORIZE appearing anywhere in the aggregates;</li>
 *   <li>an attribute inside an aggregate function whose id matches an alias of a CATEGORIZE grouping.</li>
 * </ol>
 * <p>
 * Some of these restrictions are temporary, until the required syntax or engine changes are implemented.
 * </p>
 *
 * @param agg      the aggregation whose groupings and aggregates are inspected
 * @param failures the collection every detected problem is added to
 */
private static void checkCategorizeGrouping(Aggregate agg, Set<Failure> failures) {
    var groupings = agg.groupings();
    var aggregates = agg.aggregates();

    // 1. CATEGORIZE is not allowed alongside any other grouping
    if (groupings.size() > 1) {
        for (var grouping : groupings) {
            grouping.forEachDown(
                Categorize.class,
                found -> failures.add(
                    fail(found, "cannot use CATEGORIZE grouping function [{}] with multiple groupings", found.sourceText())
                )
            );
        }
    }

    // 2. CATEGORIZE must be the top-level grouping expression: strip the alias (if any) and
    // only look below the top-level expression, so that a top-level CATEGORIZE itself is not reported
    for (var grouping : groupings) {
        for (var operand : Alias.unwrap(grouping).children()) {
            operand.forEachDown(
                Categorize.class,
                nested -> failures.add(
                    fail(nested, "CATEGORIZE grouping function [{}] can't be used within other expressions", nested.sourceText())
                )
            );
        }
    }

    // 3. CATEGORIZE is not allowed anywhere on the aggregates side
    for (var aggregateExpression : aggregates) {
        aggregateExpression.forEachDown(
            Categorize.class,
            found -> failures.add(
                fail(found, "cannot use CATEGORIZE grouping function [{}] within the aggregations", found.sourceText())
            )
        );
    }

    // 4. An aliased CATEGORIZE grouping can't be referenced from within an aggregate function.
    // First collect the ids of every alias that directly wraps a CATEGORIZE...
    Map<NameId, Categorize> aliasedCategorizes = new HashMap<>();
    for (var grouping : groupings) {
        grouping.forEachDown(Alias.class, alias -> {
            if (alias.child() instanceof Categorize aliased) {
                aliasedCategorizes.put(alias.id(), aliased);
            }
        });
    }
    // ...then report every attribute inside an aggregate function that points at one of them
    for (var aggregateExpression : aggregates) {
        aggregateExpression.forEachDown(
            AggregateFunction.class,
            function -> function.forEachDown(Attribute.class, reference -> {
                // Only non-null values are ever stored, so a key check is equivalent to a null check on the value
                if (aliasedCategorizes.containsKey(reference.id())) {
                    failures.add(
                        fail(
                            reference,
                            "cannot reference CATEGORIZE grouping function [{}] within the aggregations",
                            reference.sourceText()
                        )
                    );
                }
            })
        );
    }
}

private static void checkRateAggregates(Expression expr, int nestedLevel, Set<Failure> failures) {
if (expr instanceof AggregateFunction) {
nestedLevel++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1737,6 +1737,68 @@ public void testIntervalAsString() {
);
}

/**
 * CATEGORIZE is accepted as the only grouping, and reported when combined with any other grouping.
 */
public void testCategorizeSingleGrouping() {
    // A single CATEGORIZE grouping, with or without an alias, goes through analysis
    query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
    query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");

    // CATEGORIZE first, plain field second
    String categorizeThenField = error("from test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no");
    assertEquals(
        "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
        categorizeThenField
    );

    // Plain field first, CATEGORIZE second
    String fieldThenCategorize = error("FROM test | STATS COUNT(*) BY emp_no, CATEGORIZE(first_name)");
    assertEquals(
        "1:39: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
        fieldThenCategorize
    );

    // Both groupings aliased
    String aliasedGroupings = error("FROM test | STATS COUNT(*) BY a = CATEGORIZE(first_name), b = emp_no");
    assertEquals(
        "1:35: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
        aliasedGroupings
    );

    // Two different CATEGORIZE groupings: each one is reported
    String twoCategorizes = error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(last_name)");
    assertEquals(
        "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings\n"
            + "line 1:55: cannot use CATEGORIZE grouping function [CATEGORIZE(last_name)] with multiple groupings",
        twoCategorizes
    );

    // The same CATEGORIZE grouping repeated: a single message is expected
    String repeatedCategorize = error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(first_name)");
    assertEquals(
        "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
        repeatedCategorize
    );
}

/**
 * Other expressions may appear inside CATEGORIZE, but CATEGORIZE itself may not appear inside another expression.
 */
public void testCategorizeNestedGrouping() {
    // Expressions nested within CATEGORIZE go through analysis
    query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");

    // CATEGORIZE as a function argument
    String insideFunction = error("FROM test | STATS COUNT(*) BY MV_COUNT(CATEGORIZE(first_name))");
    assertEquals(
        "1:40: CATEGORIZE grouping function [CATEGORIZE(first_name)] can't be used within other expressions",
        insideFunction
    );

    // CATEGORIZE as the operand of an inline cast
    String insideCast = error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name)::datetime");
    assertEquals(
        "1:31: CATEGORIZE grouping function [CATEGORIZE(first_name)] can't be used within other expressions",
        insideCast
    );
}

/**
 * CATEGORIZE can neither be used directly nor be referenced from within an aggregate function.
 */
public void testCategorizeWithinAggregations() {
    // Referencing the grouping outside of an aggregate function goes through analysis
    query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");

    // CATEGORIZE used directly as the argument of an aggregate function
    String directUsage = error("FROM test | STATS COUNT(CATEGORIZE(first_name)) BY CATEGORIZE(first_name)");
    assertEquals(
        "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within the aggregations",
        directUsage
    );

    // The alias of the grouping referenced as the argument of an aggregate function
    String aliasReference = error("FROM test | STATS COUNT(cat) BY cat = CATEGORIZE(first_name)");
    assertEquals(
        "1:25: cannot reference CATEGORIZE grouping function [cat] within the aggregations",
        aliasReference
    );

    // The alias referenced deeper inside the argument of an aggregate function
    String deepAliasReference = error(
        "FROM test | STATS SUM(LENGTH(cat::keyword) + LENGTH(last_name)) BY cat = CATEGORIZE(first_name)"
    );
    assertEquals(
        "1:30: cannot reference CATEGORIZE grouping function [cat] within the aggregations",
        deepAliasReference
    );

    // An unaliased grouping referenced through its quoted source text
    String quotedReference = error("FROM test | STATS COUNT(`CATEGORIZE(first_name)`) BY CATEGORIZE(first_name)");
    assertEquals(
        "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within the aggregations",
        quotedReference
    );
}

/**
 * Parses the given query text into a statement and runs it through the default analyzer.
 * The analysis result is discarded.
 *
 * @param query the query text to parse and analyze
 */
private void query(String query) {
    var statement = parser.createStatement(query);
    defaultAnalyzer.analyze(statement);
}
Expand Down

0 comments on commit 3e64817

Please sign in to comment.