From 609c15f0ae4220701d8aaad21f92d9cc244c9287 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 13 Apr 2024 11:48:20 -0400 Subject: [PATCH] fix(sql): allow CTEs in `.sql` method --- ibis/backends/sql/__init__.py | 2 ++ ibis/backends/sql/rewrites.py | 2 +- ibis/backends/tests/test_dot_sql.py | 38 +++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index 3ad0ead0387c..91f7598c43ee 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -221,6 +221,8 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema: cte = self._to_sqlglot(table) parsed = sg.parse_one(query, read=dialect) + if parsed.args.get("with"): + parsed = sg.select(STAR).from_(parsed.subquery("_")) parsed.args["with"] = cte.args.pop("with", []) parsed = parsed.with_( sg.to_identifier(name, quoted=compiler.quoted), as_=cte, dialect=dialect diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py index f8efaa19b264..5b6747983119 100644 --- a/ibis/backends/sql/rewrites.py +++ b/ibis/backends/sql/rewrites.py @@ -187,7 +187,7 @@ def extract_ctes(node): g = Graph.from_bfs(node, filter=~InstanceOf(dont_count)) for node, dependents in g.invert().items(): - if isinstance(node, ops.View) or ( + if isinstance(node, (ops.View, ops.SQLStringView)) or ( len(dependents) > 1 and isinstance(node, cte_types) ): result.append(node) diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index 886dd9a85e66..ea7cc9d50d6c 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -358,3 +358,41 @@ def test_bare_minimum(con, alltypes, df): name = _NAMES.get(con.name, "functional_alltypes").replace('"', "") expr = alltypes.sql(f'SELECT COUNT(*) AS "n" FROM "{name}"', dialect="duckdb") assert expr.to_pandas().iat[0, 0] == len(df) + + +@dot_sql_never +@pytest.mark.notyet( + ["polars"], + raises=PolarsComputeError, + reason="polars doesn't support selecting from quoted identifiers referencing CTEs", +) +@pytest.mark.notyet( + ["druid"], + raises=KeyError, + reason="upstream does not preserve column names in schema inference", +) +def test_cte_basic(con, df): + t = con.tables.functional_alltypes + sql = 'with "x" as (select * from "functional_alltypes") select * from "x"' + expr = t.sql(sql, dialect="duckdb") + result = expr.execute() + tm.assert_frame_equal(result, df) + + +@dot_sql_never +@pytest.mark.notyet( + ["polars"], + raises=PolarsComputeError, + reason="polars doesn't support selecting from quoted identifiers referencing CTEs", +) +@pytest.mark.notyet( + ["druid"], + raises=KeyError, + reason="upstream does not preserve column names in schema inference", +) +def test_cte_with_alias(con, df): + t = con.tables.functional_alltypes + sql = 'with "x" as (select * from "foo") select * from "x"' + expr = t.alias("foo").sql(sql, dialect="duckdb") + result = expr.execute() + tm.assert_frame_equal(result, df)