Skip to content

Commit

Permalink
feat(datafusion): add TimestampTruncate / fix broken extract time par…
Browse files Browse the repository at this point in the history
…t functions
  • Loading branch information
mesejo authored and cpcloud committed Oct 19, 2023
1 parent f53a523 commit 940ed21
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 7 deletions.
41 changes: 37 additions & 4 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ibis.expr.operations as ops
from ibis import NA
from ibis.backends.datafusion import registry
from ibis.common.temporal import IntervalUnit
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType

Expand Down Expand Up @@ -927,20 +928,36 @@ def extract_quarter(op, **kw):
@translate.register(ops.ExtractMinute)
def extract_minute(op, **kw):
    """Translate ``ops.ExtractMinute`` into a DataFusion expression.

    Time-typed arguments are routed to the ``extract_minute_time`` UDF from
    the registry; timestamp-typed arguments use ``date_part("minute", ...)``.

    Raises
    ------
    com.OperationNotDefinedError
        If the argument is neither a time nor a timestamp.
    """
    arg = translate(op.arg, **kw)
    dtype = op.arg.dtype

    # Dispatch on the argument dtype first: returning date_part
    # unconditionally would bypass the time-specific UDF below.
    if dtype.is_time():
        return registry.UDFS["extract_minute_time"](arg)
    if dtype.is_timestamp():
        return df.functions.date_part(df.literal("minute"), arg)

    raise com.OperationNotDefinedError(
        f"The function is not defined for {type(op.arg)}"
    )


@translate.register(ops.ExtractHour)
def extract_hour(op, **kw):
    """Translate ``ops.ExtractHour`` into a DataFusion expression.

    Time-typed arguments are routed to the ``extract_hour_time`` UDF from
    the registry; timestamp-typed arguments use ``date_part("hour", ...)``.

    Raises
    ------
    com.OperationNotDefinedError
        If the argument is neither a time nor a timestamp.
    """
    arg = translate(op.arg, **kw)
    dtype = op.arg.dtype

    # Dispatch on the argument dtype first: returning date_part
    # unconditionally would bypass the time-specific UDF below.
    if dtype.is_time():
        return registry.UDFS["extract_hour_time"](arg)
    if dtype.is_timestamp():
        return df.functions.date_part(df.literal("hour"), arg)

    raise com.OperationNotDefinedError(
        f"The function is not defined for {type(op.arg)}"
    )


@translate.register(ops.ExtractMillisecond)
def extract_millisecond(op, **kw):
arg = translate(op.arg, **kw)

if op.arg.dtype.is_date():
if op.arg.dtype.is_time():
return registry.UDFS["extract_millisecond_time"](arg)
elif op.arg.dtype.is_timestamp():
return registry.UDFS["extract_millisecond_timestamp"](arg)
Expand All @@ -954,7 +971,7 @@ def extract_millisecond(op, **kw):
def extract_second(op, **kw):
arg = translate(op.arg, **kw)

if op.arg.dtype.is_date():
if op.arg.dtype.is_time():
return registry.UDFS["extract_second_time"](arg)
elif op.arg.dtype.is_timestamp():
return registry.UDFS["extract_second_timestamp"](arg)
Expand Down Expand Up @@ -1028,3 +1045,19 @@ def extract_epoch_seconds(op, **kw):
raise com.OperationNotDefinedError(
f"The function is not defined for {type(op.arg)}"
)


@translate.register(ops.TimestampTruncate)
def timestamp_truncate(op, **kw):
    """Translate ``ops.TimestampTruncate`` into a ``date_trunc`` call.

    The truncation unit is passed to ``date_trunc`` as the lower-cased name
    of ``op.unit``.

    Raises
    ------
    com.UnsupportedOperationError
        If the unit is millisecond, microsecond or nanosecond.
    """
    operand = translate(op.arg, **kw)

    rejected_units = (
        IntervalUnit.MILLISECOND,
        IntervalUnit.MICROSECOND,
        IntervalUnit.NANOSECOND,
    )
    granularity = op.unit
    if granularity in rejected_units:
        raise com.UnsupportedOperationError(
            f"The function is not defined for time unit {granularity}"
        )

    unit_literal = df.literal(granularity.name.lower())
    return df.functions.date_trunc(unit_literal, operand)
17 changes: 17 additions & 0 deletions ibis/backends/datafusion/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def extract_millisecond(array: pa.Array) -> pa.Array:
return pc.cast(pc.millisecond(array), pa.int32())


def extract_hour(array: pa.Array) -> pa.Array:
    """Return the hour component of *array*, cast to int32."""
    hours = pc.hour(array)
    return pc.cast(hours, pa.int32())


def extract_minute(array: pa.Array) -> pa.Array:
    """Return the minute component of *array*, cast to int32."""
    minutes = pc.minute(array)
    return pc.cast(minutes, pa.int32())


UDFS = {
"extract_microseconds_time": create_udf(
ops.ExtractMicrosecond,
Expand Down Expand Up @@ -111,4 +119,13 @@ def extract_millisecond(array: pa.Array) -> pa.Array:
input_types=[dt.timestamp],
name="extract_millisecond_timestamp",
),
"extract_hour_time": create_udf(
ops.ExtractHour, extract_hour, input_types=[dt.time], name="extract_hour_time"
),
"extract_minute_time": create_udf(
ops.ExtractMinute,
extract_minute,
input_types=[dt.time],
name="extract_minute_time",
),
}
38 changes: 38 additions & 0 deletions ibis/backends/datafusion/tests/test_temporal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from operator import methodcaller

import pytest
from pytest import param

import ibis


@pytest.mark.parametrize(
    ("func", "expected"),
    [
        param(methodcaller("hour"), 14, id="hour"),
        param(methodcaller("minute"), 48, id="minute"),
        param(methodcaller("second"), 5, id="second"),
        param(methodcaller("millisecond"), 359, id="millisecond"),
    ],
)
def test_time_extract_literal(con, func, expected):
    """Check each time-part extraction against a fixed time literal."""
    time_literal = ibis.time("14:48:05.359")
    expr = func(time_literal).name("tmp")
    result = con.execute(expr)
    assert result == expected
22 changes: 19 additions & 3 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,14 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df):
"ms",
marks=[
pytest.mark.notimpl(
["clickhouse", "impala", "mysql", "pyspark", "sqlite"],
[
"clickhouse",
"impala",
"mysql",
"pyspark",
"sqlite",
"datafusion",
],
raises=com.UnsupportedOperationError,
),
pytest.mark.broken(
Expand All @@ -447,7 +454,15 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df):
"us",
marks=[
pytest.mark.notimpl(
["clickhouse", "impala", "mysql", "pyspark", "sqlite", "trino"],
[
"clickhouse",
"impala",
"mysql",
"pyspark",
"sqlite",
"trino",
"datafusion",
],
raises=com.UnsupportedOperationError,
),
pytest.mark.broken(
Expand All @@ -473,6 +488,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df):
"snowflake",
"trino",
"mssql",
"datafusion",
],
raises=com.UnsupportedOperationError,
),
Expand All @@ -485,7 +501,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df):
),
],
)
@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["druid"],
raises=AttributeError,
Expand Down

0 comments on commit 940ed21

Please sign in to comment.