Skip to content

Commit

Permalink
Int64 as default type for make_array function empty or null case (#10790
Browse files Browse the repository at this point in the history
)

* set default type i64

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Jun 6, 2024
1 parent c580ef4 commit 053b53e
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 48 deletions.
10 changes: 3 additions & 7 deletions datafusion/functions-array/src/empty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::utils::make_scalar_function;
use arrow_array::{ArrayRef, BooleanArray, OffsetSizeTrait};
use arrow_schema::DataType;
use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List};
use datafusion_common::cast::{as_generic_list_array, as_null_array};
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
Expand Down Expand Up @@ -85,12 +85,7 @@ pub fn array_empty_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
return exec_err!("array_empty expects one argument");
}

if as_null_array(&args[0]).is_ok() {
// Make sure to return Boolean type.
return Ok(Arc::new(BooleanArray::new_null(args[0].len())));
}
let array_type = args[0].data_type();

match array_type {
List(_) => general_array_empty::<i32>(&args[0]),
LargeList(_) => general_array_empty::<i64>(&args[0]),
Expand All @@ -100,9 +95,10 @@ pub fn array_empty_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

fn general_array_empty<O: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef> {
let array = as_generic_list_array::<O>(array)?;

let builder = array
.iter()
.map(|arr| arr.map(|arr| arr.len() == arr.null_count()))
.map(|arr| arr.map(|arr| arr.is_empty()))
.collect::<BooleanArray>();
Ok(Arc::new(builder))
}
19 changes: 13 additions & 6 deletions datafusion/functions-array/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ impl ScalarUDFImpl for MakeArray {

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match arg_types.len() {
0 => Ok(DataType::List(Arc::new(Field::new(
"item",
DataType::Null,
true,
)))),
0 => Ok(empty_array_type()),
_ => {
let mut expr_type = DataType::Null;
for arg_type in arg_types {
Expand All @@ -94,6 +90,10 @@ impl ScalarUDFImpl for MakeArray {
}
}

if expr_type.is_null() {
expr_type = DataType::Int64;
}

Ok(List(Arc::new(Field::new("item", expr_type, true))))
}
}
Expand Down Expand Up @@ -131,6 +131,11 @@ impl ScalarUDFImpl for MakeArray {
}
}

// Empty array is a special case that is useful for many other array functions
pub(super) fn empty_array_type() -> DataType {
DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))
}

/// `make_array_inner` is the implementation of the `make_array` function.
/// Constructs an array using the input `data` as `ArrayRef`.
/// Returns a reference-counted `Array` instance result.
Expand All @@ -147,7 +152,9 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result<ArrayRef> {
match data_type {
// Either an empty array or all nulls:
Null => {
let array = new_null_array(&Null, arrays.iter().map(|a| a.len()).sum());
let length = arrays.iter().map(|a| a.len()).sum();
// By default Int64
let array = new_null_array(&DataType::Int64, length);
Ok(Arc::new(array_into_list_array(array)))
}
LargeList(..) => array_array::<i64>(arrays, data_type),
Expand Down
15 changes: 13 additions & 2 deletions datafusion/functions-array/src/set_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.

use crate::make_array::make_array_inner;
use crate::make_array::{empty_array_type, make_array_inner};
use crate::utils::make_scalar_function;
use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait};
use arrow::buffer::OffsetBuffer;
Expand Down Expand Up @@ -135,7 +135,7 @@ impl ScalarUDFImpl for ArrayIntersect {
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match (arg_types[0].clone(), arg_types[1].clone()) {
(Null, Null) | (Null, _) => Ok(Null),
(_, Null) => Ok(List(Arc::new(Field::new("item", Null, true)))),
(_, Null) => Ok(empty_array_type()),
(dt, _) => Ok(dt),
}
}
Expand Down Expand Up @@ -259,6 +259,17 @@ fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
return general_array_distinct::<OffsetSize>(l, &field);
}

// Handle empty array at rhs case
// array_union(arr, []) -> arr;
// array_intersect(arr, []) -> [];
if r.value_length(0).is_zero() {
if set_op == SetOp::Union {
return Ok(Arc::new(l.clone()) as ArrayRef);
} else {
return Ok(Arc::new(r.clone()) as ArrayRef);
}
}

if l.value_type() != r.value_type() {
return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'");
}
Expand Down
81 changes: 52 additions & 29 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ AS VALUES
(arrow_cast(make_array([[1,2]], [[3, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1,2]], [[4, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1,2], [3, 4]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1,2]], [[4, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1,2,3], [1]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1], [2]], []), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([2], [3]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1], [2]], []), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1], [2]], [[]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([2], [3]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1], [2]], [[]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1], [2]], [[2], [3]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')),
(arrow_cast(make_array([[1], [2]], [[2], [3]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))'))
;
Expand Down Expand Up @@ -2038,6 +2038,13 @@ NULL
[, 51, 52, 54, 55, 56, 57, 58, 59, 60]
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70]

# test with empty array
query ?
select array_sort([]);
----
[]

# test with empty row, the row that does not match the condition has row count 0
statement ok
create table t1(a int, b int) as values (100, 1), (101, 2), (102, 3), (101, 2);

Expand Down Expand Up @@ -2083,10 +2090,10 @@ select

query ????
select
array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4),
array_append(arrow_cast(make_array(), 'LargeList(Null)'), null),
array_append(arrow_cast(make_array(), 'LargeList(Int64)'), 4),
array_append(arrow_cast(make_array(), 'LargeList(Int64)'), null),
array_append(arrow_cast(make_array(1, null, 3), 'LargeList(Int64)'), 4),
array_append(arrow_cast(make_array(null, null), 'LargeList(Null)'), 1)
array_append(arrow_cast(make_array(null, null), 'LargeList(Int64)'), 1)
;
----
[4] [] [1, , 3, 4] [, , 1]
Expand Down Expand Up @@ -2567,7 +2574,7 @@ query ????
select
array_repeat(arrow_cast([1], 'LargeList(Int64)'), 5),
array_repeat(arrow_cast([1.1, 2.2, 3.3], 'LargeList(Float64)'), 3),
array_repeat(arrow_cast([null, null], 'LargeList(Null)'), 3),
array_repeat(arrow_cast([null, null], 'LargeList(Int64)'), 3),
array_repeat(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2);
----
[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
Expand Down Expand Up @@ -2630,6 +2637,12 @@ drop table large_array_repeat_table;

## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)

# test with empty array
query ?
select array_concat([]);
----
[]

# array_concat error
query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\.
select array_concat(1, 2);
Expand Down Expand Up @@ -2674,19 +2687,19 @@ select array_concat(make_array(), make_array(2, 3));
query ?
select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array()));
----
[[1, 2], [3, 4]]
[[1, 2], [3, 4], []]

# array_concat scalar function #8 (with empty arrays)
query ?
select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array()), make_array(make_array(), make_array()), make_array(make_array(5, 6), make_array(7, 8)));
----
[[1, 2], [3, 4], [5, 6], [7, 8]]
[[1, 2], [3, 4], [], [], [], [5, 6], [7, 8]]

# array_concat scalar function #9 (with empty arrays)
query ?
select array_concat(make_array(make_array()), make_array(make_array(1, 2), make_array(3, 4)));
----
[[1, 2], [3, 4]]
[[], [1, 2], [3, 4]]

# array_cat scalar function #10 (function alias `array_concat`)
query ??
Expand Down Expand Up @@ -3788,7 +3801,7 @@ select array_union([1,2,3], []);
[1, 2, 3]

query ?
select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Null)'));
select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Int64)'));
----
[1, 2, 3]

Expand Down Expand Up @@ -3836,7 +3849,7 @@ select array_union([], []);
[]

query ?
select array_union(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)'));
select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Int64)'));
----
[]

Expand All @@ -3847,7 +3860,7 @@ select array_union([[null]], []);
[[]]

query ?
select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([], 'LargeList(Null)'));
select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)'));
----
[[]]

Expand All @@ -3858,7 +3871,7 @@ select array_union([null], [null]);
[]

query ?
select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([[null]], 'LargeList(List(Null))'));
select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([[null]], 'LargeList(List(Int64))'));
----
[[]]

Expand All @@ -3869,7 +3882,7 @@ select array_union(null, []);
[]

query ?
select array_union(null, arrow_cast([], 'LargeList(Null)'));
select array_union(null, arrow_cast([], 'LargeList(Int64)'));
----
[]

Expand Down Expand Up @@ -4106,14 +4119,14 @@ select cardinality(make_array()), cardinality(make_array(make_array()))
NULL 0

query II
select cardinality(arrow_cast(make_array(), 'LargeList(Null)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Null))'))
select cardinality(arrow_cast(make_array(), 'LargeList(Int64)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))'))
----
NULL 0

#TODO
#https://github.com/apache/datafusion/issues/9158
#query II
#select cardinality(arrow_cast(make_array(), 'FixedSizeList(1, Null)')), cardinality(arrow_cast(make_array(make_array()), 'FixedSizeList(1, List(Null))'))
#select cardinality(arrow_cast(make_array(), 'FixedSizeList(1, Null)')), cardinality(arrow_cast(make_array(make_array()), 'FixedSizeList(1, List(Int64))'))
#----
#NULL 0

Expand Down Expand Up @@ -4699,7 +4712,7 @@ select array_dims(make_array()), array_dims(make_array(make_array()))
NULL [1, 0]

query ??
select array_dims(arrow_cast(make_array(), 'LargeList(Null)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))'))
select array_dims(arrow_cast(make_array(), 'LargeList(Int64)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))'))
----
NULL [1, 0]

Expand Down Expand Up @@ -4861,7 +4874,7 @@ select array_ndims(make_array()), array_ndims(make_array(make_array()))
1 2

query II
select array_ndims(arrow_cast(make_array(), 'LargeList(Null)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))'))
select array_ndims(arrow_cast(make_array(), 'LargeList(Int64)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))'))
----
1 2

Expand All @@ -4882,7 +4895,7 @@ select list_ndims(make_array()), list_ndims(make_array(make_array()))
1 2

query II
select list_ndims(arrow_cast(make_array(), 'LargeList(Null)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))'))
select list_ndims(arrow_cast(make_array(), 'LargeList(Int64)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))'))
----
1 2

Expand Down Expand Up @@ -5500,7 +5513,7 @@ select array_intersect([], []);
[]

query ?
select array_intersect(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)'));
select array_intersect(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Int64)'));
----
[]

Expand Down Expand Up @@ -5530,7 +5543,17 @@ select array_intersect([], null);
[]

query ?
select array_intersect(arrow_cast([], 'LargeList(Null)'), null);
select array_intersect([[1,2,3]], [[]]);
----
[]

query ?
select array_intersect([[null]], [[]]);
----
[]

query ?
select array_intersect(arrow_cast([], 'LargeList(Int64)'), null);
----
[]

Expand All @@ -5540,7 +5563,7 @@ select array_intersect(null, []);
NULL

query ?
select array_intersect(null, arrow_cast([], 'LargeList(Null)'));
select array_intersect(null, arrow_cast([], 'LargeList(Int64)'));
----
NULL

Expand Down Expand Up @@ -6196,7 +6219,7 @@ select empty(make_array());
true

query B
select empty(arrow_cast(make_array(), 'LargeList(Null)'));
select empty(arrow_cast(make_array(), 'LargeList(Int64)'));
----
true

Expand All @@ -6213,12 +6236,12 @@ select empty(make_array(NULL));
false

query B
select empty(arrow_cast(make_array(NULL), 'LargeList(Null)'));
select empty(arrow_cast(make_array(NULL), 'LargeList(Int64)'));
----
false

query B
select empty(arrow_cast(make_array(NULL), 'FixedSizeList(1, Null)'));
select empty(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)'));
----
false

Expand Down Expand Up @@ -6282,7 +6305,7 @@ select array_empty(make_array());
true

query B
select array_empty(arrow_cast(make_array(), 'LargeList(Null)'));
select array_empty(arrow_cast(make_array(), 'LargeList(Int64)'));
----
true

Expand All @@ -6293,7 +6316,7 @@ select array_empty(make_array(NULL));
false

query B
select array_empty(arrow_cast(make_array(NULL), 'LargeList(Null)'));
select array_empty(arrow_cast(make_array(NULL), 'LargeList(Int64)'));
----
false

Expand All @@ -6316,7 +6339,7 @@ select list_empty(make_array());
true

query B
select list_empty(arrow_cast(make_array(), 'LargeList(Null)'));
select list_empty(arrow_cast(make_array(), 'LargeList(Int64)'));
----
true

Expand All @@ -6327,7 +6350,7 @@ select list_empty(make_array(NULL));
false

query B
select list_empty(arrow_cast(make_array(NULL), 'LargeList(Null)'));
select list_empty(arrow_cast(make_array(NULL), 'LargeList(Int64)'));
----
false

Expand Down
Loading

0 comments on commit 053b53e

Please sign in to comment.