Skip to content

Commit

Permalink
[CALCITE-6605] Lattice SQL supports complex column expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
YiwenWu committed Oct 2, 2024
1 parent a4a27e3 commit 907ab4c
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 3 deletions.
23 changes: 20 additions & 3 deletions core/src/main/java/org/apache/calcite/materialize/Lattice.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.calcite.schema.impl.StarTable;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlDialect;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
Expand Down Expand Up @@ -88,6 +89,7 @@
import static com.google.common.base.Preconditions.checkArgument;

import static org.apache.calcite.linq4j.Nullness.castNonNull;
import static org.apache.calcite.rel.rel2sql.SqlImplementor.POS;

import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -282,9 +284,7 @@ public String sql(ImmutableBitSet groupSet, boolean group,
final StringBuilder groupBuf = new StringBuilder("\nGROUP BY ");
int k = 0;
final Set<String> columnNames = new HashSet<>();
final SqlWriter w = createSqlWriter(dialect, buf, f -> {
throw new UnsupportedOperationException();
});
final SqlWriter w = createSqlWriter(dialect, buf, resolveField(dialect));
if (groupSet != null) {
for (int i : groupSet) {
if (k++ > 0) {
Expand Down Expand Up @@ -370,6 +370,23 @@ public String sql(ImmutableBitSet groupSet, boolean group,
return buf.toString();
}

/** Resolves a field index to a corresponding SqlNode based on the column type. */
private IntFunction<SqlNode> resolveField(SqlDialect dialect) {
final IntFunction<SqlNode>[] fieldFuncRef = new IntFunction[1];
fieldFuncRef[0] = f -> {
Column column = columns.get(f);
if (column instanceof BaseColumn) {
return new SqlIdentifier(ImmutableList.of(((BaseColumn) column).column), POS);
}
if (column instanceof DerivedColumn) {
return new SqlImplementor.SimpleContext(dialect, fieldFuncRef[0])
.toSql(null, ((DerivedColumn) column).e);
}
throw new UnsupportedOperationException();
};
return fieldFuncRef[0];
}

/** Creates a context to which SQL can be generated. */
public SqlWriter createSqlWriter(SqlDialect dialect, StringBuilder buf,
IntFunction<SqlNode> field) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
* limitations under the License.
*/
package org.apache.calcite.materialize;

import org.apache.calcite.materialize.Lattice.Measure;
import org.apache.calcite.prepare.PlannerImpl;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.schema.SchemaPlus;
Expand All @@ -35,6 +37,8 @@
import org.apache.calcite.tools.Planner;
import org.apache.calcite.tools.RelConversionException;
import org.apache.calcite.tools.ValidationException;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableBitSet.Builder;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -752,6 +756,97 @@ private void checkDerivedColumn(Lattice lattice, List<String> tables,
checkDerivedColumn(lattice, tables, derivedColumns, 3, "n11", false);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6605">[CALCITE-6605]
* Lattice SQL supports complex column expressions </a>. */
@Test void testExpressionLatticeSql() throws Exception {
final Tester t = new Tester().foodmart().withEvolve(true);
final String q0 = "select\n"
+ " \"num_children_at_home\" + 12 as \"n12\",\n"
+ " sum(\"num_children_at_home\") as \"n10\",\n"
+ " count(*) as c\n"
+ "from \"customer\"\n"
+ "group by \"num_children_at_home\" + 12";
t.addQuery(q0);
assertThat(t.s.latticeMap, aMapWithSize(1));
final Lattice lattice = Iterables.getOnlyElement(t.s.latticeMap.values());
final String l0 = "customer:[COUNT(), SUM(customer.num_children_at_home)]";
assertThat(Iterables.getOnlyElement(t.s.latticeMap.keySet()), is(l0));
ImmutableList<Measure> measures = lattice.defaultMeasures;
assert measures.size() == 2;
Builder groupSetBuilder = ImmutableBitSet.builder();
measures.forEach(measure -> groupSetBuilder.addAll(measure.argBitSet()));
ImmutableBitSet groupSet = groupSetBuilder.build();
String sql = "SELECT \"customer\".\"num_children_at_home\", COUNT(*) AS \"m0\", "
+ "SUM(\"customer\".\"num_children_at_home\") AS \"m1\"\n"
+ "FROM \"foodmart\".\"customer\" AS \"customer\"\n"
+ "GROUP BY \"customer\".\"num_children_at_home\"";
assertThat(lattice.sql(groupSet, true, measures),
is(sql));
}

/** Test case for measure field involving a complex column operation,
* for example sum("num_children_at_home" + 10). */
@Test void testExpressionLatticeSql2() throws Exception {
final Tester t = new Tester().foodmart().withEvolve(true);
final String q0 = "select\n"
+ " \"num_children_at_home\" + 12 as \"n12\",\n"
+ " sum(\"num_children_at_home\" + 10) as \"n10\",\n"
+ " sum(\"num_children_at_home\" + 11) as \"n11\",\n"
+ " count(*) as c\n"
+ "from \"customer\"\n"
+ "group by \"num_children_at_home\" + 12";
t.addQuery(q0);
assertThat(t.s.latticeMap, aMapWithSize(1));
final Lattice lattice = Iterables.getOnlyElement(t.s.latticeMap.values());
final String l0 = "customer:[COUNT(), SUM(n10), SUM(n11)]";
assertThat(Iterables.getOnlyElement(t.s.latticeMap.keySet()), is(l0));
ImmutableList<Measure> measures = lattice.defaultMeasures;
assert measures.size() == 3;
Builder groupSetBuilder = ImmutableBitSet.builder();
measures.forEach(measure -> groupSetBuilder.addAll(measure.argBitSet()));
ImmutableBitSet groupSet = groupSetBuilder.build();
String sql = "SELECT \"num_children_at_home\" + 10 AS \"n10\", "
+ "\"num_children_at_home\" + 11 AS \"n11\", COUNT(*) AS \"m0\", "
+ "SUM(\"num_children_at_home\" + 10) AS \"m1\", "
+ "SUM(\"num_children_at_home\" + 11) AS \"m2\"\n"
+ "FROM \"foodmart\".\"customer\" AS \"customer\"\n"
+ "GROUP BY \"num_children_at_home\" + 10, \"num_children_at_home\" + 11";
assertThat(lattice.sql(groupSet, true, measures),
is(sql));
}

/** Test case for measure field involving a complex column operation with functions,
* for example sum(cast("num_children_at_home" as double) + 11). */
@Test void testExpressionLatticeSql3() throws Exception {
final Tester t = new Tester().foodmart().withEvolve(true);
final String q0 = "select\n"
+ " \"num_children_at_home\" + 12 as \"n12\",\n"
+ " sum(\"num_children_at_home\" + 10) as \"n10\",\n"
+ " sum(cast(\"num_children_at_home\" as double) + 11) as \"n11\",\n"
+ " count(*) as c\n"
+ "from \"customer\"\n"
+ "group by \"num_children_at_home\" + 12";
t.addQuery(q0);
assertThat(t.s.latticeMap, aMapWithSize(1));
final Lattice lattice = Iterables.getOnlyElement(t.s.latticeMap.values());
final String l0 = "customer:[COUNT(), SUM(n10), SUM(n11)]";
assertThat(Iterables.getOnlyElement(t.s.latticeMap.keySet()), is(l0));
ImmutableList<Measure> measures = lattice.defaultMeasures;
assert measures.size() == 3;
Builder groupSetBuilder = ImmutableBitSet.builder();
measures.forEach(measure -> groupSetBuilder.addAll(measure.argBitSet()));
ImmutableBitSet groupSet = groupSetBuilder.build();
String sql = "SELECT \"num_children_at_home\" + 10 AS \"n10\", "
+ "CAST(\"num_children_at_home\" AS DOUBLE) + 11 AS \"n11\", "
+ "COUNT(*) AS \"m0\", SUM(\"num_children_at_home\" + 10) AS \"m1\", "
+ "SUM(CAST(\"num_children_at_home\" AS DOUBLE) + 11) AS \"m2\"\n"
+ "FROM \"foodmart\".\"customer\" AS \"customer\"\n"
+ "GROUP BY \"num_children_at_home\" + 10, CAST(\"num_children_at_home\" AS DOUBLE) + 11";
assertThat(lattice.sql(groupSet, true, measures),
is(sql));
}

private void checkFoodmartSimpleJoin(CalciteAssert.SchemaSpec schemaSpec)
throws Exception {
final FrameworkConfig config = Frameworks.newConfigBuilder()
Expand Down

0 comments on commit 907ab4c

Please sign in to comment.