Skip to content

Commit

Permalink
perf(expressions): speed up .describe() and .info() expression co…
Browse files Browse the repository at this point in the history
…nstruction
  • Loading branch information
cpcloud committed Jul 24, 2024
1 parent eba1b76 commit 2cdda1a
Show file tree
Hide file tree
Showing 5 changed files with 407 additions and 90 deletions.
268 changes: 268 additions & 0 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import math
import operator
import string
from collections import deque
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, ClassVar

Expand All @@ -22,7 +23,9 @@
LastValue,
add_one_to_nth_value_input,
add_order_by_to_empty_ranking_window_functions,
describe_to_generic_describe,
empty_in_values_right_side,
info_to_generic_info,
lower_bucket,
lower_capitalize,
lower_sample,
Expand Down Expand Up @@ -238,6 +241,8 @@ class SQLGlotCompiler(abc.ABC):
add_order_by_to_empty_ranking_window_functions,
one_to_zero_index,
add_one_to_nth_value_input,
info_to_generic_info,
describe_to_generic_describe,
)
"""A sequence of rewrites to apply to the expression tree before compilation."""

Expand Down Expand Up @@ -1527,6 +1532,269 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
)
return sg.select(*columns_to_keep).from_(parent)

def visit_GenericInfo(self, op, *, parent, **_):
quoted = self.quoted
schema = op.parent.schema

table = sg.to_identifier(parent.alias_or_name, quoted=quoted)

aggs = deque()
for colname, pos in schema._name_locs.items():
typ = schema[colname]

col = sge.column(colname, table=table, quoted=quoted).is_(None)
isna = self.cast(col, dt.int32)

aggs.append(
sg.select(
sge.convert(colname).as_("name", quoted=quoted),
sge.convert(str(typ)).as_("type", quoted=quoted),
sge.convert(typ.nullable).as_("nullable", quoted=quoted),
self.f.sum(isna).as_("nulls", quoted=quoted),
self.f.sum(1 - isna).as_("non_nulls", quoted=quoted),
self.f.avg(isna).as_("null_frac", quoted=quoted),
sge.convert(pos).as_("pos", quoted=quoted),
).from_(parent)
)

# rebalance aggs, this speeds up sqlglot compilation of huge unions
# significantly
while len(aggs) > 1:
left = aggs.popleft()
right = aggs.popleft()
aggs.append(sg.union(left, right, distinct=False))

unions = aggs.popleft()

assert not aggs, "not all unions processed"

return unions.order_by(sg.column("pos", quoted=quoted).asc())

def visit_FastInfo(self, op, *, parent, **_):
names = []
types = []
nullables = []
nullses = []
non_nullses = []
null_fracs = []
poses = []
quoted = self.quoted
schema = op.parent.schema

table = sg.to_identifier(parent.alias_or_name, quoted=quoted)

for colname, pos in schema._name_locs.items():
typ = schema[colname]

col = sge.column(colname, table=table, quoted=quoted).is_(None)
isna = self.cast(col, dt.int32)

names.append(sge.convert(colname))
types.append(sge.convert(str(typ)))
nullables.append(sge.convert(typ.nullable))
nullses.append(self.f.sum(isna))
non_nullses.append(self.f.sum(1 - isna))
null_fracs.append(self.f.avg(isna))
poses.append(sge.convert(pos))

return (
sg.select(
self.f.unnest(self.f.array(*names)).as_("name", quoted=quoted),
self.f.unnest(self.f.array(*types)).as_("type", quoted=quoted),
self.f.unnest(self.f.array(*nullables)).as_("nullable", quoted=quoted),
self.f.unnest(self.f.array(*nullses)).as_("nulls", quoted=quoted),
self.f.unnest(self.f.array(*non_nullses)).as_(
"non_nulls", quoted=quoted
),
self.f.unnest(self.f.array(*null_fracs)).as_(
"null_frac", quoted=quoted
),
self.f.unnest(self.f.array(*poses)).as_("pos", quoted=quoted),
)
.from_(parent)
.order_by(sg.column("pos", quoted=quoted).asc())
)

def visit_GenericDescribe(self, op, *, parent, quantile, **_):
quantile = sorted(quantile)
schema = op.parent.schema
quoted = self.quoted

quantile_keys = tuple(
f"p{100 * q:.6f}".rstrip("0").rstrip(".") for q in quantile
)
default_quantiles = dict.fromkeys(quantile_keys, NULL)
table = sg.to_identifier(parent.alias_or_name, quoted=quoted)
aggs = deque()

for colname, pos in schema._name_locs.items():
col = sge.column(colname, table=table, quoted=quoted)
typ = schema[colname]

# statistics default to NULL
col_mean = col_std = col_min = col_max = col_mode = NULL
quantile_values = default_quantiles.copy()

if typ.is_numeric():
col_mean = self.f.avg(col)
col_std = self.f.stddev(col)
col_min = self.f.min(col)
col_max = self.f.max(col)
for key, q in zip(quantile_keys, quantile):
quantile_values[key] = sge.Quantile(
this=col, quantile=sge.convert(q)
)

elif typ.is_string():
col_mode = self.f.mode(col)
elif typ.is_boolean():
col_mean = self.f.avg(self.f.cast(col, dt.int32))
else:
# Will not calculate statistics for other types
continue

aggs.append(
sg.select(
sge.convert(colname).as_("name", quoted=quoted),
sge.convert(pos).as_("pos", quoted=quoted),
sge.convert(str(typ)).as_("type", quoted=quoted),
self.f.count(col).as_("count", quoted=quoted),
self.f.sum(self.cast(col.is_(NULL), dt.int32)).as_(
"nulls", quoted=quoted
),
self.f.count(sge.Distinct(expressions=[col])).as_(
"unique", quoted=quoted
),
col_mode.as_("mode", quoted=quoted),
col_mean.as_("mean", quoted=quoted),
col_std.as_("std", quoted=quoted),
col_min.as_("min", quoted=quoted),
*(val.as_(q, quoted=quoted) for q, val in quantile_values.items()),
col_max.as_("max", quoted=quoted),
).from_(parent)
)

# rebalance aggs, this speeds up sqlglot compilation of huge unions
# significantly
while len(aggs) > 1:
left = aggs.popleft()
right = aggs.popleft()
aggs.append(sg.union(left, right, distinct=False))

unions = aggs.popleft()

assert not aggs, "not all unions processed"

return unions

def visit_FastDescribe(self, op, *, parent, quantile, **_):
quantile = sorted(quantile)
schema = op.parent.schema
quoted = self.quoted

names = list(map(sge.convert, schema._name_locs.keys()))
poses = list(map(sge.convert, schema._name_locs.values()))
types = list(map(sge.convert, map(str, schema.values())))
counts = []
nulls = []
uniques = []
modes = []
means = []
stds = []
mins = []
quantiles = {}
maxs = []

quantile_keys = tuple(
f"p{100 * q:.6f}".rstrip("0").rstrip(".") for q in quantile
)
default_quantiles = dict.fromkeys(quantile_keys, NULL)
quantiles = {key: [] for key in quantile_keys}
table = sg.to_identifier(parent.alias_or_name, quoted=quoted)

for colname in schema._name_locs.keys():
col = sge.column(colname, table=table, quoted=quoted)
typ = schema[colname]

# statistics default to NULL
col_mean = col_std = col_min = col_max = col_mode = NULL
quantile_values = default_quantiles.copy()

if typ.is_numeric():
col_mean = self.f.avg(col)
col_std = self.f.stddev(col)
col_min = self.f.min(col)
col_max = self.f.max(col)
for key, q in zip(quantile_keys, quantile):
quantile_values[key] = sge.Quantile(
this=col, quantile=sge.convert(q)
)

elif typ.is_string():
col_mode = self.f.mode(col)
elif typ.is_boolean():
col_mean = self.f.avg(self.f.cast(col, dt.int32))
else:
# Will not calculate statistics for other types
continue

counts.append(self.f.count(col))
nulls.append(self.f.sum(self.cast(col.is_(NULL), dt.int32)))
uniques.append(self.f.count(sge.Distinct(expressions=[col])))
modes.append(col_mode)
means.append(col_mean)
stds.append(col_std)
mins.append(col_min)

for q, val in quantile_values.items():
quantiles[q].append(val)

maxs.append(col_max)

opschema = op.schema

return sg.select(
self.f.unnest(
self.cast(self.f.array(*names), dt.Array(opschema["name"]))
).as_("name", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*poses), dt.Array(opschema["pos"]))
).as_("pos", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*types), dt.Array(opschema["type"]))
).as_("type", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*counts), dt.Array(opschema["count"]))
).as_("count", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*nulls), dt.Array(opschema["nulls"]))
).as_("nulls", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*uniques), dt.Array(opschema["unique"]))
).as_("unique", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*modes), dt.Array(opschema["mode"]))
).as_("mode", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*means), dt.Array(opschema["mean"]))
).as_("mean", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*stds), dt.Array(opschema["std"]))
).as_("std", quoted=quoted),
self.f.unnest(
self.cast(self.f.array(*mins), dt.Array(opschema["min"]))
).as_("min", quoted=quoted),
*(
self.f.unnest(
self.cast(self.f.array(*vals), dt.Array(opschema[q]))
).as_(q, quoted=quoted)
for q, vals in quantiles.items()
),
self.f.unnest(
self.cast(self.f.array(*maxs), dt.Array(opschema["max"]))
).as_("max", quoted=quoted),
).from_(parent)


# `__init_subclass__` is uncalled for subclasses - we manually call it here to
# autogenerate the base class implementations as well.
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import ibis.expr.operations as ops
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.backends.sql.rewrites import (
describe_to_fast_describe,
exclude_nulls_from_array_collect,
info_to_fast_info,
)
from ibis.util import gen_name

_INTERVAL_SUFFIXES = {
Expand All @@ -37,6 +41,8 @@ class DuckDBCompiler(SQLGlotCompiler):

rewrites = (
exclude_nulls_from_array_collect,
info_to_fast_info,
describe_to_fast_describe,
*SQLGlotCompiler.rewrites,
)

Expand Down
74 changes: 74 additions & 0 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,80 @@ def dtype(self):
return self.arg.dtype


@public
class GenericInfo(ops.Relation):
"""Generic info node."""

parent: ops.Relation
schema: Schema

@attribute
def values(self):
return {}


@public
class FastInfo(ops.Relation):
"""Fast info node for backends that support arrays."""

parent: ops.Relation
schema: Schema

@attribute
def values(self):
return {}


@replace(p.Info)
def info_to_generic_info(_, **kwargs):
"""Convert Info node to GenericInfo node."""
return GenericInfo(parent=_.parent, schema=_.schema)


@replace(p.Info)
def info_to_fast_info(_, **kwargs):
"""Convert Info node to GenericInfo node."""
return FastInfo(parent=_.parent, schema=_.schema)


@public
class GenericDescribe(ops.Relation):
"""Generic describe node."""

parent: ops.Relation
quantile: VarTuple[float]
schema: Schema

@attribute
def values(self):
return {}


@public
class FastDescribe(ops.Relation):
"""Fast describe node for backends that support arrays."""

parent: ops.Relation
quantile: VarTuple[float]
schema: Schema

@attribute
def values(self):
return {}


@replace(p.Describe)
def describe_to_generic_describe(_, **kwargs):
"""Convert Info node to GenericInfo node."""
return GenericDescribe(parent=_.parent, quantile=_.quantile, schema=_.schema)


@replace(p.Describe)
def describe_to_fast_describe(_, **kwargs):
"""Convert Info node to GenericInfo node."""
return FastDescribe(parent=_.parent, quantile=_.quantile, schema=_.schema)


# TODO(kszucs): there is a better strategy to rewrite the relational operations
# to Select nodes by wrapping the leaf nodes in a Select node and then merging
# Project, Filter, Sort, etc. incrementally into the Select node. This way we
Expand Down
Loading

0 comments on commit 2cdda1a

Please sign in to comment.