diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 7be07a7659f66..d399c826e0bf2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; @@ -33,6 +34,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; @@ -56,10 +58,12 @@ import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -271,6 +275,7 @@ private static void checkAggregate(LogicalPlan p, Set failures) { r -> failures.add(fail(r, "the rate aggregate[{}] can only be used within the metrics command", r.sourceText())) ); } + checkCategorizeGrouping(agg, failures); } else { p.forEachExpression( GroupingFunction.class, @@ -279,6 +284,74 @@ private static void checkAggregate(LogicalPlan p, Set failures) { } } + /** + * Check CATEGORIZE grouping function usages. + *

+ * Some of those checks are temporary, until the required syntax or engine changes are implemented. + *

+ */ + private static void checkCategorizeGrouping(Aggregate agg, Set failures) { + // Forbid CATEGORIZE grouping function with other groupings + if (agg.groupings().size() > 1) { + agg.groupings().forEach(g -> { + g.forEachDown( + Categorize.class, + categorize -> failures.add( + fail(categorize, "cannot use CATEGORIZE grouping function [{}] with multiple groupings", categorize.sourceText()) + ) + ); + }); + } + + // Forbid CATEGORIZE grouping functions not being top level groupings + agg.groupings().forEach(g -> { + // Check all CATEGORIZE but the top level one + Alias.unwrap(g) + .children() + .forEach( + child -> child.forEachDown( + Categorize.class, + c -> failures.add( + fail(c, "CATEGORIZE grouping function [{}] can't be used within other expressions", c.sourceText()) + ) + ) + ); + }); + + // Forbid CATEGORIZE being used in the aggregations + agg.aggregates().forEach(a -> { + a.forEachDown( + Categorize.class, + categorize -> failures.add( + fail(categorize, "cannot use CATEGORIZE grouping function [{}] within the aggregations", categorize.sourceText()) + ) + ); + }); + + // Forbid CATEGORIZE being referenced in the aggregation functions + Map categorizeByAliasId = new HashMap<>(); + agg.groupings().forEach(g -> { + g.forEachDown(Alias.class, alias -> { + if (alias.child() instanceof Categorize categorize) { + categorizeByAliasId.put(alias.id(), categorize); + } + }); + }); + agg.aggregates() + .forEach(a -> a.forEachDown(AggregateFunction.class, aggregate -> aggregate.forEachDown(Attribute.class, attribute -> { + var categorize = categorizeByAliasId.get(attribute.id()); + if (categorize != null) { + failures.add( + fail( + attribute, + "cannot reference CATEGORIZE grouping function [{}] within the aggregations", + attribute.sourceText() + ) + ); + } + }))); + } + private static void checkRateAggregates(Expression expr, int nestedLevel, Set failures) { if (expr instanceof AggregateFunction) { nestedLevel++; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 0e0c2de11fac3..8b364a603405c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1737,6 +1737,68 @@ public void testIntervalAsString() { ); } + public void testCategorizeSingleGrouping() { + query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)"); + query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)"); + + assertEquals( + "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", + error("from test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no") + ); + assertEquals( + "1:39: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", + error("FROM test | STATS COUNT(*) BY emp_no, CATEGORIZE(first_name)") + ); + assertEquals( + "1:35: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", + error("FROM test | STATS COUNT(*) BY a = CATEGORIZE(first_name), b = emp_no") + ); + assertEquals( + "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings\n" + + "line 1:55: cannot use CATEGORIZE grouping function [CATEGORIZE(last_name)] with multiple groupings", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(last_name)") + ); + assertEquals( + "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(first_name)") + ); + } + + public void testCategorizeNestedGrouping() { + query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)"); + + assertEquals( + "1:40: CATEGORIZE grouping function [CATEGORIZE(first_name)] can't be used within other expressions", + error("FROM test | STATS COUNT(*) BY MV_COUNT(CATEGORIZE(first_name))") + ); + assertEquals( + "1:31: CATEGORIZE grouping function [CATEGORIZE(first_name)] can't be used within other expressions", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name)::datetime") + ); + } + + public void testCategorizeWithinAggregations() { + query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)"); + + assertEquals( + "1:25: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] within the aggregations", + error("FROM test | STATS COUNT(CATEGORIZE(first_name)) BY CATEGORIZE(first_name)") + ); + + assertEquals( + "1:25: cannot reference CATEGORIZE grouping function [cat] within the aggregations", + error("FROM test | STATS COUNT(cat) BY cat = CATEGORIZE(first_name)") + ); + assertEquals( + "1:30: cannot reference CATEGORIZE grouping function [cat] within the aggregations", + error("FROM test | STATS SUM(LENGTH(cat::keyword) + LENGTH(last_name)) BY cat = CATEGORIZE(first_name)") + ); + assertEquals( + "1:25: cannot reference CATEGORIZE grouping function [`CATEGORIZE(first_name)`] within the aggregations", + error("FROM test | STATS COUNT(`CATEGORIZE(first_name)`) BY CATEGORIZE(first_name)") + ); + } + private void query(String query) { defaultAnalyzer.analyze(parser.createStatement(query)); }