Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(expressions): speed up .describe() and .info() expression construction #9684

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 286 additions & 0 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import math
import operator
import string
from collections import deque
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, ClassVar

Expand Down Expand Up @@ -1607,6 +1608,291 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
)
return sg.select(*columns_to_keep).from_(parent)

def visit_GenericInfo(self, op, *, parent, **_):
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
quoted = self.quoted
schema = op.parent.schema

table = sg.to_identifier(parent.alias_or_name, quoted=quoted)

aggs = deque()
for colname, pos in schema._name_locs.items():
typ = schema[colname]

col = sge.column(colname, table=table, quoted=quoted).is_(None)
isna = self.cast(col, dt.int32)

aggs.append(
sg.select(
sge.convert(colname).as_("name", quoted=quoted),
sge.convert(str(typ)).as_("type", quoted=quoted),
sge.convert(typ.nullable).as_("nullable", quoted=quoted),
self.agg.sum(isna).as_("nulls", quoted=quoted),
self.agg.sum(1 - isna).as_("non_nulls", quoted=quoted),
self.agg.avg(isna).as_("null_frac", quoted=quoted),
sge.convert(pos).as_("pos", quoted=quoted),
).from_(parent)
)

# rebalance aggs, this speeds up sqlglot compilation of huge unions
# significantly
while len(aggs) > 1:
left = aggs.popleft()
right = aggs.popleft()
aggs.append(sg.union(left, right, distinct=False))

unions = aggs.popleft()

assert not aggs, "not all unions processed"

return unions.order_by(sg.column("pos", quoted=quoted).asc())

def visit_FastInfo(self, op, *, parent, **_):
names = []
types = []
nullables = []
nullses = []
non_nullses = []
null_fracs = []
poses = []
quoted = self.quoted
schema = op.parent.schema

table = sg.to_identifier(parent.alias_or_name, quoted=quoted)

for colname, pos in schema._name_locs.items():
typ = schema[colname]

col = sge.column(colname, table=table, quoted=quoted).is_(None)
isna = self.cast(col, dt.int32)

names.append(sge.convert(colname))
types.append(sge.convert(str(typ)))
nullables.append(sge.convert(typ.nullable))
nullses.append(self.agg.sum(isna))
non_nullses.append(self.agg.sum(1 - isna))
null_fracs.append(self.agg.avg(isna))
poses.append(sge.convert(pos))

return (
sg.select(
self.f.explode(self.f.array(*names)).as_("name", quoted=quoted),
self.f.explode(self.f.array(*types)).as_("type", quoted=quoted),
self.f.explode(self.f.array(*nullables)).as_("nullable", quoted=quoted),
self.f.explode(self.f.array(*nullses)).as_("nulls", quoted=quoted),
self.f.explode(self.f.array(*non_nullses)).as_(
"non_nulls", quoted=quoted
),
self.f.explode(self.f.array(*null_fracs)).as_(
"null_frac", quoted=quoted
),
self.f.explode(self.f.array(*poses)).as_("pos", quoted=quoted),
)
.from_(parent)
.order_by(sg.column("pos", quoted=quoted).asc())
)

def visit_GenericDescribe(self, op, *, parent, quantile, **_):
quantile = sorted(quantile)
schema = op.parent.schema
opschema = op.schema
quoted = self.quoted

quantile_keys = tuple(
f"p{100 * q:.6f}".rstrip("0").rstrip(".") for q in quantile
)
default_quantiles = dict.fromkeys(quantile_keys, NULL)
table = sg.to_identifier(parent.alias_or_name, quoted=quoted)
aggs = deque()

for colname, pos in schema._name_locs.items():
col = sge.column(colname, table=table, quoted=quoted)
typ = schema[colname]

# statistics default to NULL
col_mean = col_std = col_min = col_max = col_mode = NULL
quantile_values = default_quantiles.copy()

if typ.is_numeric():
col_mean = self.agg.avg(col)
col_std = self.agg.stddev(col)
col_min = self.agg.min(col)
col_max = self.agg.max(col)
for key, q in zip(quantile_keys, quantile):
quantile_values[key] = sge.Quantile(
this=col, quantile=sge.convert(q)
)

elif typ.is_string():
if ops.Mode not in self.UNSUPPORTED_OPS:
col_mode = self.agg.mode(col)
else:
col_mode = self.cast(NULL, opschema["mode"])
elif typ.is_boolean():
col_mean = self.agg.avg(self.cast(col, dt.int32))
else:
# Will not calculate statistics for other types
continue

aggs.append(
sg.select(
sge.convert(colname).as_("name", quoted=quoted),
sge.convert(pos).as_("pos", quoted=quoted),
sge.convert(str(typ)).as_("type", quoted=quoted),
self.agg.count(col).as_("count", quoted=quoted),
self.agg.sum(self.cast(col.is_(NULL), dt.int32)).as_(
"nulls", quoted=quoted
),
self.agg.count(sge.Distinct(expressions=[col])).as_(
"unique", quoted=quoted
),
col_mode.as_("mode", quoted=quoted),
self.cast(col_mean, opschema["mean"]).as_("mean", quoted=quoted),
self.cast(col_std, opschema["std"]).as_("std", quoted=quoted),
self.cast(col_min, opschema["min"]).as_("min", quoted=quoted),
*(
self.cast(val, opschema[q]).as_(q, quoted=quoted)
for q, val in quantile_values.items()
),
self.cast(col_max, opschema["max"]).as_("max", quoted=quoted),
).from_(parent)
)

# rebalance aggs, this speeds up sqlglot compilation of huge unions
# significantly
while len(aggs) > 1:
left = aggs.popleft()
right = aggs.popleft()
aggs.append(sg.union(left, right, distinct=False))

unions = aggs.popleft()

assert not aggs, "not all unions processed"

return unions

def visit_FastDescribe(self, op, *, parent, quantile, **_):
quantile = sorted(quantile)
schema = op.parent.schema
quoted = self.quoted

name_locs = schema._name_locs
parent_schema_names = name_locs.keys()

names = list(map(sge.convert, parent_schema_names))
poses = list(map(sge.convert, name_locs.values()))
types = list(map(sge.convert, map(str, schema.values())))
counts = []
nulls = []
uniques = []
modes = []
means = []
stds = []
mins = []
quantiles = {}
maxs = []

quantile_keys = tuple(
f"p{100 * q:.6f}".rstrip("0").rstrip(".") for q in quantile
)
default_quantiles = dict.fromkeys(quantile_keys, NULL)
quantiles = {key: [] for key in quantile_keys}
table = sg.to_identifier(parent.alias_or_name, quoted=quoted)
opschema = op.schema

for colname in parent_schema_names:
col = sge.column(colname, table=table, quoted=quoted)
typ = schema[colname]

# statistics default to NULL
col_mean = col_std = col_min = col_max = col_mode = NULL
quantile_values = default_quantiles.copy()

if typ.is_numeric():
col_mean = self.agg.avg(col)
col_std = self.agg.stddev(col)
col_min = self.agg.min(col)
col_max = self.agg.max(col)
for key, q in zip(quantile_keys, quantile):
quantile_values[key] = sge.Quantile(
this=col, quantile=sge.convert(q)
)

elif typ.is_string():
if ops.Mode not in self.UNSUPPORTED_OPS:
col_mode = self.agg.mode(col)
else:
col_mode = self.cast(NULL, opschema["mode"])
elif typ.is_boolean():
col_mean = self.agg.avg(self.cast(col, dt.int32))
else:
# Will not calculate statistics for other types
continue

counts.append(self.agg.count(col))
nulls.append(self.agg.sum(self.cast(col.is_(NULL), dt.int32)))
uniques.append(self.agg.count(sge.Distinct(expressions=[col])))
modes.append(col_mode)
means.append(col_mean)
stds.append(col_std)
mins.append(col_min)

for q, val in quantile_values.items():
quantiles[q].append(val)

maxs.append(col_max)

return sg.select(
self.f.explode(
self.cast(self.f.array(*names), dt.Array(opschema["name"]))
).as_("name", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*poses), dt.Array(opschema["pos"]))
).as_("pos", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*types), dt.Array(opschema["type"]))
).as_("type", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*counts), dt.Array(opschema["count"]))
).as_("count", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*nulls), dt.Array(opschema["nulls"]))
).as_("nulls", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*uniques), dt.Array(opschema["unique"]))
).as_("unique", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*modes), dt.Array(opschema["mode"]))
).as_("mode", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*means), dt.Array(opschema["mean"]))
).as_("mean", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*stds), dt.Array(opschema["std"]))
).as_("std", quoted=quoted),
self.f.explode(
self.cast(self.f.array(*mins), dt.Array(opschema["min"]))
).as_("min", quoted=quoted),
*(
self.f.explode(
self.cast(self.f.array(*vals), dt.Array(opschema[q]))
).as_(q, quoted=quoted)
for q, vals in quantiles.items()
),
self.f.explode(
self.cast(self.f.array(*maxs), dt.Array(opschema["max"]))
).as_("max", quoted=quoted),
).from_(parent)

def visit_Info(self, op, *, parent):
if ops.Unnest in self.UNSUPPORTED_OPS:
return self.visit_GenericInfo(op, parent=parent)
return self.visit_FastInfo(op, parent=parent)

def visit_Describe(self, op, *, parent, quantile):
if ops.Unnest in self.UNSUPPORTED_OPS:
return self.visit_GenericDescribe(op, parent=parent, quantile=quantile)
return self.visit_FastDescribe(op, parent=parent, quantile=quantile)


# `__init_subclass__` is uncalled for subclasses - we manually call it here to
# autogenerate the base class implementations as well.
Expand Down
25 changes: 23 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Iterator, Mapping


class ClickhouseAggGen(AggGen):
class ClickHouseAggGen(AggGen):
def aggregate(self, compiler, name, *args, where=None, order_by=()):
if order_by:
raise com.UnsupportedOperationError(
Expand All @@ -33,14 +33,26 @@ def aggregate(self, compiler, name, *args, where=None, order_by=()):
args += (where,)
return compiler.f[name](*args, dialect=compiler.dialect)

def mode(self, arg, where=None):
func = "topK"

params = [arg]
if where is not None:
func += "If"
params.append(where)

return sge.ParameterizedAgg(
this=func, expressions=[sge.convert(1)], params=params
)[1]


class ClickHouseCompiler(SQLGlotCompiler):
__slots__ = ()

dialect = ClickHouse
type_mapper = ClickHouseType

agg = ClickhouseAggGen()
agg = ClickHouseAggGen()

supports_qualify = True

Expand Down Expand Up @@ -104,6 +116,7 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.MapMerge: "mapUpdate",
ops.MapValues: "mapValues",
ops.Median: "quantileExactExclusive",
ops.Mode: "mode",
ops.NotNull: "isNotNull",
ops.NullIf: "nullIf",
ops.RStrip: "trimRight",
Expand Down Expand Up @@ -816,5 +829,13 @@ def visit_MapContains(self, op, *, arg, key):
sg.or_(arg.is_(NULL), key.is_(NULL)), NULL, self.f.mapContains(arg, key)
)

def visit_Info(self, op, *, parent):
# clickhouse explode is not pairwise, so fall back to generic impl
return self.visit_GenericInfo(op, parent=parent)

def visit_Describe(self, op, *, parent, quantile):
# clickhouse explode is not pairwise, so fall back to generic impl
return self.visit_GenericDescribe(op, parent=parent, quantile=quantile)


compiler = ClickHouseCompiler()
Loading
Loading