From 6a8fb76c8f71fc5929b10d3d949dfdfbc652b410 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:16:24 -0400 Subject: [PATCH] chore(trino): use generic summary ops --- ibis/backends/sql/compilers/sqlite.py | 2 +- ibis/backends/sql/compilers/trino.py | 8 ++++++++ ibis/backends/sql/dialects.py | 17 ++++++++++++++++- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index 62ce583062da..4ad1debdcda8 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -450,7 +450,7 @@ def visit_DayOfWeekName(self, op, *, arg): ) def visit_Xor(self, op, *, left, right): - return (left.or_(right)).and_(sg.not_(left.and_(right))) + return left.or_(right).and_(sg.not_(left.and_(right))) def visit_NonNullLiteral(self, op, *, value, dtype): if dtype.is_binary(): diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 4a19b9a37436..5836c8b5ee1b 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -659,5 +659,13 @@ def visit_ArraySum(self, op, *, arg): def visit_ArrayMean(self, op, *, arg): return self.visit_ArraySumAgg(op, arg=arg, output=operator.truediv) + def visit_Info(self, op, *, parent): + # unnest cannot contain aggregates + return self.visit_GenericInfo(op, parent=parent) + + def visit_Describe(self, op, *, parent, quantile): + # unnest cannot contain aggregates + return self.visit_GenericDescribe(op, parent=parent, quantile=quantile) + compiler = TrinoCompiler() diff --git a/ibis/backends/sql/dialects.py b/ibis/backends/sql/dialects.py index 217fdf34e1e2..c064a88fac77 100644 --- a/ibis/backends/sql/dialects.py +++ b/ibis/backends/sql/dialects.py @@ -20,7 +20,7 @@ Trino, ) from sqlglot.dialects.dialect import rename_func -from sqlglot.helper import find_new_name, seq_get +from sqlglot.helper import find_new_name, flatten, seq_get ClickHouse.Generator.TRANSFORMS |= { sge.ArraySize: rename_func("length"), @@ -440,7 +440,22 @@ class Generator(Postgres.Generator): sge.Levenshtein: rename_func("editdistance"), } + +# return lambda self, expression: self.func(name, *flatten(expression.args.values())) SQLite.Generator.TYPE_MAPPING |= {sge.DataType.Type.BOOLEAN: "BOOLEAN"} +SQLite.Generator.TRANSFORMS |= { + sge.Stddev: lambda self, e: self.func( + "sqrt", self.func("_ibis_var_sample", *flatten(e.args.values())) + ), + sge.StddevSamp: lambda self, e: self.func( + "sqrt", self.func("_ibis_var_sample", *flatten(e.args.values())) + ), + sge.StddevPop: lambda self, e: self.func( + "sqrt", self.func("_ibis_var_pop", *flatten(e.args.values())) + ), + sge.Variance: rename_func("_ibis_var_samp"), + sge.VariancePop: rename_func("_ibis_var_pop"), +} # TODO(cpcloud): remove this hack once