Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add test for impure function correlation behavior #9014

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,6 +1643,7 @@ def _register_udf(self, udf_node: ops.ScalarUDF):
for param in udf_node.__signature__.parameters.values()
]
output_type = type_mapper.to_string(udf_node.dtype)
config = udf_node.__config__

def register_udf(con):
return con.create_function(
Expand All @@ -1651,6 +1652,7 @@ def register_udf(con):
input_types,
output_type,
type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__],
**config,
)

return register_udf
Expand Down
14 changes: 7 additions & 7 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF):
type_mapper = self.type_mapper
argnames = udf_node.argnames
return """\
CREATE OR REPLACE FUNCTION {ident}({signature})
RETURNS {return_type}
LANGUAGE {language}
AS $$
{source}
return {name}({args})
$$""".format(
CREATE OR REPLACE FUNCTION {ident}({signature})
RETURNS {return_type}
LANGUAGE {language}
AS $$
{source}
return {name}({args})
$$""".format(
name=type(udf_node).__name__,
ident=self.__sql_name__(udf_node),
signature=", ".join(
Expand Down
4 changes: 4 additions & 0 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import operator
import sys
from collections.abc import Mapping
from functools import reduce
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -230,6 +231,9 @@ def complexity(node):
def accum(node, *args):
if isinstance(node, ops.Field):
return 1
elif isinstance(node, ops.Impure):
# consider (potentially) impure functions maximally complex
return sys.maxsize
else:
return 1 + sum(args)

Expand Down
221 changes: 221 additions & 0 deletions ibis/backends/tests/test_impure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
from __future__ import annotations

import sys

import pytest

import ibis
import ibis.common.exceptions as com
from ibis import _
from ibis.backends.tests.errors import Py4JJavaError

tm = pytest.importorskip("pandas.testing")

pytestmark = pytest.mark.xdist_group("impure")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Concurrent execution of CREATE OR REPLACE FUNCTION in postgres doesn't seem to work. This ensures that all tests in this module run in the same process as long as --dist=loadgroup is passed, which it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, that seems like something we shouldn't worry about further. Maybe add a comment here? Also fine to not, if someone removes it they will find out from failing tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a comment, because I already forgot when revisiting this in review 😅 !


no_randoms = [
pytest.mark.notimpl(
["polars", "druid", "risingwave"], raises=com.OperationNotDefinedError
),
]

no_udfs = [
pytest.mark.notyet("datafusion", raises=NotImplementedError),
pytest.mark.notimpl(
[
"bigquery",
"clickhouse",
"druid",
"exasol",
"impala",
"mssql",
"mysql",
"oracle",
"trino",
"risingwave",
]
),
pytest.mark.notyet(
"flink",
condition=sys.version_info >= (3, 11),
raises=Py4JJavaError,
reason="Docker image has Python 3.10, results in `cloudpickle` version mismatch",
),
]

no_uuids = [
pytest.mark.notimpl(
["druid", "exasol", "oracle", "polars", "pyspark", "risingwave"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet("mssql", reason="Unrelated bug: Incorrect syntax near '('"),
]


@ibis.udf.scalar.python(side_effects=True)
def my_random(x: float) -> float:
# need to make the whole UDF self-contained for postgres to work
import random

return random.random() # noqa: S311


mark_impures = pytest.mark.parametrize(
"impure",
[
pytest.param(lambda _: ibis.random(), marks=no_randoms, id="random"),
pytest.param(
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
marks=[
*no_uuids,
pytest.mark.notyet("impala", reason="instances are uncorrelated"),
],
id="uuid",
),
pytest.param(
lambda table: my_random(table.float_col),
marks=[
*no_udfs,
pytest.mark.notyet(["flink"], reason="instances are uncorrelated"),
],
id="udf",
),
],
)


# You can work around this by .cache()ing the table.
@pytest.mark.notyet("sqlite", reason="instances are uncorrelated")
@mark_impures
def test_impure_correlated(alltypes, impure):
# An "impure" expression is random(), uuid(), or some other non-deterministic UDF.
# If we evaluate it for two different rows in the same relation,
# we might get different results. This is expected.
# But, as soon as we .select() it into a new relation, then that "locks in" the
# value, and any further references to it will be the same.
# eg if you look at the following SQL:
# WITH
# t AS (SELECT random() AS common)
# SELECT common as x, common as y FROM t
# Then both x and y should have the same value.
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
expr = alltypes.select(common=impure(alltypes)).select(x=_.common, y=_.common)
df = expr.execute()
tm.assert_series_equal(df.x, df.y, check_names=False)


# You can work around this by .cache()ing the table.
@pytest.mark.notyet("sqlite", reason="instances are uncorrelated")
@mark_impures
def test_chained_selections(alltypes, impure):
# https://github.com/ibis-project/ibis/issues/8921#issue-2234327722
# This is a slightly more complex version of test_impure_correlated.
# consider this SQL:
# WITH
# t AS (SELECT random() AS num)
# SELECT num, num > 0.5 AS isbig FROM t
# We would expect that the value of num and isbig are consistent,
# since we "lock in" the value of num by selecting it into t.
t = alltypes.select(num=impure(alltypes))
t = t.mutate(isbig=(t.num > 0.5))
df = t.execute()
df["expected"] = df.num > 0.5
tm.assert_series_equal(df.isbig, df.expected, check_names=False)


impure_params_uncorrelated = pytest.mark.parametrize(
"impure",
[
pytest.param(
lambda _: ibis.random(),
marks=[
*no_randoms,
pytest.mark.notyet(["impala"], reason="instances are correlated"),
],
id="random",
),
pytest.param(
# make this a float so we can compare to .5
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
marks=[
*no_uuids,
pytest.mark.notyet(["mysql"], reason="instances are correlated"),
],
id="uuid",
),
pytest.param(
lambda table: my_random(table.float_col),
marks=[
*no_udfs,
# no "impure" argument for pyspark yet
pytest.mark.notimpl("pyspark"),
Copy link
Contributor Author

@NickCrews NickCrews Aug 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

notyet instead of notimpl, since it's a problem on the backend side? And move the comment into the reason kwarg?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not, we just don't pass it through, unless pyspark doesn't have the ability to set this property.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see, thanks. Looks like it is implemented in spark, we just need to expose it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this in a follow-up. Lucky for us, each engine appears to use a slightly different way of specifying whether a UDF is deterministic 🙄

],
id="udf",
),
],
)


# You can work around this by doing .select().cache().select()
@pytest.mark.notyet(["clickhouse"], reason="instances are correlated")
@impure_params_uncorrelated
def test_impure_uncorrelated_different_id(alltypes, impure):
# This is the opposite of test_impure_correlated.
# If we evaluate an impure expression for two different rows in the same relation,
# the should be uncorrelated.
# eg if you look at the following SQL:
# select random() as x, random() as y
# Then x and y should be uncorrelated.
expr = alltypes.select(x=impure(alltypes), y=impure(alltypes))
df = expr.execute()
assert (df.x != df.y).any()


# You can work around this by doing .select().cache().select()
@pytest.mark.notyet(["clickhouse"], reason="instances are correlated")
@impure_params_uncorrelated
def test_impure_uncorrelated_same_id(alltypes, impure):
# Similar to test_impure_uncorrelated_different_id, but the two expressions
# have the same ID. Still, they should be uncorrelated.
common = impure(alltypes)
expr = alltypes.select(x=common, y=common)
df = expr.execute()
assert (df.x != df.y).any()


@pytest.mark.notyet(
[
"duckdb",
"clickhouse",
"datafusion",
"mysql",
"impala",
"mssql",
"trino",
"flink",
"bigquery",
],
raises=AssertionError,
reason="instances are not correlated but ideally they would be",
)
@pytest.mark.notyet(
["sqlite"],
raises=AssertionError,
reason="instances are *sometimes* correlated but ideally they would always be",
strict=False,
)
@pytest.mark.notimpl(
["polars", "risingwave", "druid", "exasol", "oracle", "pyspark"],
raises=com.OperationNotDefinedError,
)
def test_self_join_with_generated_keys(con):
# Even with CTEs in the generated SQL, the backends still
# materialize a new value every time it is referenced.
# This isn't ideal behavior, but there is nothing we can do about it
# on the ibis side. The best you can do is to .cache() the table
# right after you assign the uuid().
# https://github.com/ibis-project/ibis/pull/9014#issuecomment-2399449665
left = ibis.memtable({"idx": list(range(5))}).mutate(key=ibis.uuid())
right = left.filter(left.idx < 3)
expr = left.join(right, "key")
result = con.execute(expr.count())
assert result == 3
6 changes: 5 additions & 1 deletion ibis/expr/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
ops.StringContains: "contains",
ops.StringSQLILike: "ilike",
ops.StringSQLLike: "like",
ops.TimestampNow: "now",
}


Expand Down Expand Up @@ -84,6 +83,11 @@ def translate(op, *args, **kwargs):
raise NotImplementedError(op)


@translate.register(ops.TimestampNow)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed for this PR or just an unrelated fixup? (looks like the right change regardless)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's needed for this PR, but I'm not entirely sure why. I'll poke around a bit.

def now(_):
return "ibis.now()"


@translate.register(ops.Value)
def value(op, *args, **kwargs):
method = _get_method_name(op)
Expand Down
6 changes: 4 additions & 2 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,19 @@ class Impure(Value):


@public
class TimestampNow(Constant):
class TimestampNow(Impure):
"""Return the current timestamp."""

dtype = dt.timestamp
shape = ds.scalar


@public
class DateNow(Constant):
class DateNow(Impure):
"""Return the current date."""

dtype = dt.date
shape = ds.scalar


@public
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class InputType(enum.Enum):


@public
class ScalarUDF(ops.Value):
class ScalarUDF(ops.Impure):
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
@attribute
def shape(self):
if not (args := getattr(self, "args")): # noqa: B009
Expand All @@ -65,7 +65,7 @@ def shape(self):


@public
class AggUDF(ops.Reduction):
class AggUDF(ops.Reduction, ops.Impure):
where: Optional[ops.Value[dt.Boolean]] = None


Expand Down
Loading