Skip to content

Commit

Permalink
fix(sql): allow CTEs in .sql method
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Apr 13, 2024
1 parent 2b7f7b1 commit 609c15f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:

cte = self._to_sqlglot(table)
parsed = sg.parse_one(query, read=dialect)
if parsed.args.get("with"):
parsed = sg.select(STAR).from_(parsed.subquery("_"))
parsed.args["with"] = cte.args.pop("with", [])
parsed = parsed.with_(
sg.to_identifier(name, quoted=compiler.quoted), as_=cte, dialect=dialect
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def extract_ctes(node):

g = Graph.from_bfs(node, filter=~InstanceOf(dont_count))
for node, dependents in g.invert().items():
if isinstance(node, ops.View) or (
if isinstance(node, (ops.View, ops.SQLStringView)) or (
len(dependents) > 1 and isinstance(node, cte_types)
):
result.append(node)
Expand Down
38 changes: 38 additions & 0 deletions ibis/backends/tests/test_dot_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,41 @@ def test_bare_minimum(con, alltypes, df):
name = _NAMES.get(con.name, "functional_alltypes").replace('"', "")
expr = alltypes.sql(f'SELECT COUNT(*) AS "n" FROM "{name}"', dialect="duckdb")
assert expr.to_pandas().iat[0, 0] == len(df)


@dot_sql_never
@pytest.mark.notyet(
["polars"],
raises=PolarsComputeError,
reason="polars doesn't support selecting from quoted identifiers referencing CTEs",
)
@pytest.mark.notyet(
["druid"],
raises=KeyError,
reason="upstream does not preserve column names in schema inference",
)
def test_cte_basic(con, df):
t = con.tables.functional_alltypes
sql = 'with "x" as (select * from "functional_alltypes") select * from "x"'
expr = t.sql(sql, dialect="duckdb")
result = expr.execute()
tm.assert_frame_equal(result, df)


@dot_sql_never
@pytest.mark.notyet(
["polars"],
raises=PolarsComputeError,
reason="polars doesn't support selecting from quoted identifiers referencing CTEs",
)
@pytest.mark.notyet(
["druid"],
raises=KeyError,
reason="upstream does not preserve column names in schema inference",
)
def test_cte_with_alias(con, df):
t = con.tables.functional_alltypes
sql = 'with "x" as (select * from "foo") select * from "x"'
expr = t.alias("foo").sql(sql, dialect="duckdb")
result = expr.execute()
tm.assert_frame_equal(result, df)

0 comments on commit 609c15f

Please sign in to comment.