Skip to content

Commit

Permalink
[CALCITE-6586] Some Rules not firing due to RelMdPredicates returning…
Browse files Browse the repository at this point in the history
… null in VolcanoPlanner
  • Loading branch information
suibianwanwan committed Sep 24, 2024
1 parent 8d3cb82 commit c883774
Show file tree
Hide file tree
Showing 27 changed files with 231 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ protected CassandraSortRule(CassandraSortRuleConfig config) {

public RelNode convert(Sort sort, CassandraFilter filter) {
final RelTraitSet traitSet =
sort.getTraitSet().replace(CassandraRel.CONVENTION)
.replace(sort.getCollation());
sort.getTraitSet().replace(CassandraRel.CONVENTION);
return new CassandraSort(sort.getCluster(), traitSet,
convert(sort.getInput(), traitSet.replace(RelCollations.EMPTY)),
sort.getCollation());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public CassandraSort(RelOptCluster cluster, RelTraitSet traitSet,

@Override public Sort copy(RelTraitSet traitSet, RelNode input,
RelCollation newCollation, @Nullable RexNode offset, @Nullable RexNode fetch) {
return new CassandraSort(getCluster(), traitSet, input, collation);
return new CassandraSort(getCluster(), traitSet, input, newCollation);
}

@Override public void implement(Implementor implementor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@
*/
package org.apache.calcite.adapter.enumerable;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelDistributionTraitDef;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.metadata.RelMdCollation;
import org.apache.calcite.rel.metadata.RelMdDistribution;
import org.apache.calcite.rel.metadata.RelMetadataQuery;

import org.immutables.value.Value;

Expand All @@ -46,19 +53,28 @@ protected EnumerableLimitRule(Config config) {

@Override public void onMatch(RelOptRuleCall call) {
final Sort sort = call.rel(0);
final RelOptCluster cluster = sort.getCluster();
final RelMetadataQuery mq = cluster.getMetadataQuery();

if (sort.offset == null && sort.fetch == null) {
return;
}
RelNode input = sort.getInput();
if (!sort.getCollation().getFieldCollations().isEmpty()) {
// Create a sort with the same sort key, but no offset or fetch.
input =
sort.copy(sort.getTraitSet(), input, sort.getCollation(), null, null);
}

final RelNode input = sort.getCollation().getFieldCollations().isEmpty()
? sort.getInput()
: sort.copy(sort.getTraitSet(), sort.getInput(), sort.getCollation(), null, null);
final RelTraitSet traitSet =
cluster.traitSetOf(EnumerableConvention.INSTANCE)
.replaceIfs(RelCollationTraitDef.INSTANCE,
() -> RelMdCollation.limit(mq, sort))
.replaceIf(RelDistributionTraitDef.INSTANCE,
() -> RelMdDistribution.limit(mq, input));
call.transformTo(
EnumerableLimit.create(
new EnumerableLimit(
cluster,
traitSet,
convert(call.getPlanner(), input,
input.getTraitSet().replace(EnumerableConvention.INSTANCE)),
input.getTraitSet().replace(EnumerableConvention.INSTANCE)),
sort.offset,
sort.fetch));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public EnumerableLimitSortRule(Config config) {
final Sort sort = call.rel(0);
RelNode input = sort.getInput();
final Sort o =
EnumerableLimitSort.create(
new EnumerableLimitSort(
sort.getCluster(),
sort.getTraitSet().replace(EnumerableConvention.INSTANCE),
convert(call.getPlanner(), input,
input.getTraitSet().replace(EnumerableConvention.INSTANCE)),
sort.getCollation(), sort.offset, sort.fetch);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ public static EnumerableSort create(RelNode child, RelCollation collation,
fetch);
}

/** Creates an EnumerableSort. */
public static EnumerableSort create(RelNode child, RelCollation collation,
RelTraitSet traitSet, @Nullable RexNode offset, @Nullable RexNode fetch) {
final RelOptCluster cluster = child.getCluster();
return new EnumerableSort(cluster, traitSet, child, collation, offset,
fetch);
}

@Override public EnumerableSort copy(
RelTraitSet traitSet,
RelNode newInput,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ protected EnumerableSortRule(Config config) {
input,
input.getTraitSet().replace(EnumerableConvention.INSTANCE)),
sort.getCollation(),
sort.getTraitSet().replace(EnumerableConvention.INSTANCE),
null,
null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ static <T extends RelMultipleTrait> RelTrait of(RelTraitDef def,
} else if (traitList.size() == 1) {
return def.canonize(traitList.get(0));
} else {
// make sure traits in RelCompositeTrait is strictly ordered
final RelMultipleTrait[] traits =
traitList.toArray(new RelMultipleTrait[0]);
traitList.stream().sorted().toArray(RelMultipleTrait[]::new);
for (int i = 0; i < traits.length; i++) {
traits[i] = (T) def.canonize(traits[i]);
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/org/apache/calcite/plan/RelTraitSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,10 @@ public boolean contains(RelTrait trait) {
public boolean containsIfApplicable(RelTrait trait) {
// Note that '==' is sufficient, because trait should be canonized.
final RelTrait trait1 = getTrait(trait.getTraitDef());
if (trait1 instanceof RelCompositeTrait) {
List<RelMultipleTrait> traitList = ((RelCompositeTrait) trait1).traitList();
return traitList.contains(trait);
}
return trait1 == null || trait1 == trait;
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/calcite/rel/core/Sort.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ protected Sort(
this.fetch = fetch;
this.hints = ImmutableList.copyOf(hints);

assert traits.containsIfApplicable(collation)
assert collation.getFieldCollations().isEmpty() || traits.containsIfApplicable(collation)
: "traits=" + traits + ", collation=" + collation;
assert !(fetch == null
&& offset == null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ public LogicalSort(RelInput input) {
public static LogicalSort create(RelNode input, RelCollation collation,
@Nullable RexNode offset, @Nullable RexNode fetch) {
RelOptCluster cluster = input.getCluster();
collation = RelCollationTraitDef.INSTANCE.canonize(collation);
RelTraitSet traitSet =
input.getTraitSet().replace(Convention.NONE).replace(collation);
final RelCollation canonize = RelCollationTraitDef.INSTANCE.canonize(collation);

RelTraitSet traitSet = input.getTraitSet().replace(Convention.NONE);
if (!canonize.getFieldCollations().isEmpty()) {
// Preserve input collation if only offset and fetch
traitSet = traitSet.replaceIf(RelCollationTraitDef.INSTANCE, () -> canonize);
}
return new LogicalSort(cluster, traitSet, input, collation, offset, fetch);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,10 @@ private RelMdCollation() {}

public @Nullable ImmutableList<RelCollation> collations(Sort sort,
RelMetadataQuery mq) {
return copyOf(
RelMdCollation.sort(sort.getCollation()));
List<RelCollation> collations =
Util.first(sort.getTraitSet().getTraits(RelCollationTraitDef.INSTANCE),
RelMdCollation.sort(sort.getCollation()));
return copyOf(collations);
}

public @Nullable ImmutableList<RelCollation> collations(SortExchange sort,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,51 +456,23 @@ public Boolean areColumnsUnique(Values rel, RelMetadataQuery mq,
ImmutableBitSet columns, boolean ignoreNulls) {
columns = decorateWithConstantColumnsFromPredicates(columns, rel, mq);
for (RelNode rel2 : rel.getRels()) {
if (rel2 instanceof Aggregate
|| rel2 instanceof Filter
|| rel2 instanceof Values
|| rel2 instanceof Sort
|| rel2 instanceof TableScan
|| simplyProjects(rel2, columns)) {
try {
final Boolean unique = mq.areColumnsUnique(rel2, columns, ignoreNulls);
if (unique != null) {
if (unique) {
return true;
}
} else {
return null;
try {
final Boolean unique = mq.areColumnsUnique(rel2, columns, ignoreNulls);
if (unique != null) {
if (unique) {
return true;
}
} catch (CyclicMetadataException e) {
// Ignore this relational expression; there will be non-cyclic ones
// in this set.
} else {
return null;
}
} catch (CyclicMetadataException e) {
// Ignore this relational expression; there will be non-cyclic ones
// in this set.
}
}
return false;
}

private static boolean simplyProjects(RelNode rel, ImmutableBitSet columns) {
if (!(rel instanceof Project)) {
return false;
}
Project project = (Project) rel;
final List<RexNode> projects = project.getProjects();
for (int column : columns) {
if (column >= projects.size()) {
return false;
}
if (!(projects.get(column) instanceof RexInputRef)) {
return false;
}
final RexInputRef ref = (RexInputRef) projects.get(column);
if (ref.getIndex() != column) {
return false;
}
}
return true;
}

/** Splits a column set between left and right sets. */
private static Pair<ImmutableBitSet, ImmutableBitSet>
splitLeftAndRightColumns(int leftCount, final ImmutableBitSet columns) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,9 @@ public RelOptPredicateList getPredicates(Values values, RelMetadataQuery mq) {
public RelOptPredicateList getPredicates(RelSubset r,
RelMetadataQuery mq) {
if (!Bug.CALCITE_1048_FIXED) {
return RelOptPredicateList.EMPTY;
// FIXME: This is a short-term fix and may disable some applicable rules.
// A complete solution will come with [CALCITE-1048].
return mq.getPulledUpPredicates(r.stripped());
}
final RexBuilder rexBuilder = r.getCluster().getRexBuilder();
RelOptPredicateList list = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
*/
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.ConventionTraitDef;
import org.apache.calcite.plan.RelOptPredicateList;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelCollations;
Expand All @@ -28,14 +30,18 @@
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;

import com.google.common.collect.ImmutableList;

import org.immutables.value.Value;

import java.util.List;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;

/**
* Planner rule that removes keys from a
* a {@link org.apache.calcite.rel.core.Sort} if those keys are known to be
* {@link org.apache.calcite.rel.core.Sort} if those keys are known to be
* constant, or removes the entire Sort if all keys are constant.
*
* <p>Requires {@link RelCollationTraitDef}.
Expand Down Expand Up @@ -73,15 +79,29 @@ protected SortRemoveConstantKeysRule(Config config) {

// No active collations. Remove the sort completely
if (collationsList.isEmpty() && sort.offset == null && sort.fetch == null) {
call.transformTo(input);
final RelTraitSet traits = sort.getInput().getTraitSet()
.replaceIfs(RelCollationTraitDef.INSTANCE,
() -> sort.getTraitSet().getTraits(RelCollationTraitDef.INSTANCE));

// We won't copy the RelTraitSet for every node in the RelSubset,
// so stripped is probably a good choice.
RelNode stripped = input.stripped();
call.transformTo(
convert(stripped.copy(traits, stripped.getInputs()),
traits.replaceIf(ConventionTraitDef.INSTANCE, sort::getConvention)));
call.getPlanner().prune(sort);
return;
}

final RelCollation collation = RelCollations.of(collationsList);
RelCollation sortCollation = sort.getTraitSet().getTrait(RelCollationTraitDef.INSTANCE);

final Sort result =
sort.copy(
sort.getTraitSet().replaceIf(RelCollationTraitDef.INSTANCE, () -> collation),
sort.getTraitSet().
replaceIfs(RelCollationTraitDef.INSTANCE,
() -> ImmutableList.of(collation,
requireNonNull(sortCollation, "sortCollation"))),
input,
collation);
call.transformTo(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ public SortRemoveRule(RelBuilderFactory relBuilderFactory) {
// Don't remove sort if would also remove OFFSET or LIMIT.
return;
}

// Composite trait not support in change trait
if (!sort.getTraitSet().allSimple()) {
return;
}
// Express the "sortedness" requirement in terms of a collation trait and
// we can get rid of the sort. This allows us to use rels that just happen
// to be sorted but get the same effect.
Expand Down
13 changes: 7 additions & 6 deletions core/src/test/java/org/apache/calcite/test/JdbcTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3176,12 +3176,13 @@ void testInnerJoinValues(String format) {
.enable(CalciteAssert.DB != CalciteAssert.DatabaseInstance.ORACLE)
.explainContains(""
+ "EnumerableAggregate(group=[{0}], m0=[COUNT($1)])\n"
+ " EnumerableAggregate(group=[{1, 3}])\n"
+ " EnumerableHashJoin(condition=[=($0, $2)], joinType=[inner])\n"
+ " EnumerableCalc(expr#0..9=[{inputs}], expr#10=[CAST($t4):INTEGER], expr#11=[1997], expr#12=[=($t10, $t11)], time_id=[$t0], the_year=[$t4], $condition=[$t12])\n"
+ " EnumerableTableScan(table=[[foodmart2, time_by_day]])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], time_id=[$t1], unit_sales=[$t7])\n"
+ " EnumerableTableScan(table=[[foodmart2, sales_fact_1997]])")
+ " EnumerableCalc(expr#0=[{inputs}], expr#1=[1997:SMALLINT], expr#2=[CAST($t1):SMALLINT], c0=[$t2], unit_sales=[$t0])\n"
+ " EnumerableAggregate(group=[{1}])\n"
+ " EnumerableHashJoin(condition=[=($0, $2)], joinType=[semi])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], time_id=[$t1], unit_sales=[$t7])\n"
+ " EnumerableTableScan(table=[[foodmart2, sales_fact_1997]])\n"
+ " EnumerableCalc(expr#0..9=[{inputs}], expr#10=[CAST($t4):INTEGER], expr#11=[1997], expr#12=[=($t10, $t11)], time_id=[$t0], the_year=[$t4], $condition=[$t12])\n"
+ " EnumerableTableScan(table=[[foodmart2, time_by_day]])\n")
.returns("c0=1997; m0=6\n");
}

Expand Down
11 changes: 11 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,17 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}

@Test void testSortRemoveAllKeysConstantInVolcano() {
final String sql = "select count(*) as c\n"
+ "from sales.emp\n"
+ "where deptno = 10\n"
+ "group by deptno, sal\n"
+ "order by deptno desc nulls last";
sql(sql)
.withVolcanoPlanner(false)
.check();
}

@Test void testSortRemovalOneKeyConstant() {
final String sql = "select count(*) as c\n"
+ "from sales.emp\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15041,6 +15041,34 @@ LogicalProject(C=[$0])
LogicalProject(DEPTNO=[$7], SAL=[$5])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testSortRemoveAllKeysConstantInVolcano">
<Resource name="sql">
<![CDATA[select count(*) as c
from sales.emp
where deptno = 10
group by deptno, sal
order by deptno desc nulls last]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(C=[$0])
LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])
LogicalProject(C=[$2], DEPTNO=[$0])
LogicalAggregate(group=[{0, 1}], C=[COUNT()])
LogicalProject(DEPTNO=[$7], SAL=[$5])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
EnumerableProject(C=[$2])
EnumerableAggregate(group=[{5, 7}], C=[COUNT()])
EnumerableFilter(condition=[=($7, 10)])
EnumerableTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down
Loading

0 comments on commit c883774

Please sign in to comment.