diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp b/be/src/pipeline/exec/aggregation_sink_operator.cpp index 260a599a947a0de..7d1c5ea09ee7e3f 100644 --- a/be/src/pipeline/exec/aggregation_sink_operator.cpp +++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp @@ -742,7 +742,7 @@ Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( _pool, tnode.agg_node.aggregate_functions[i], tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy, - &evaluator)); + tnode.agg_node.grouping_exprs.empty(), &evaluator)); _aggregate_evaluators.push_back(evaluator); } diff --git a/be/src/pipeline/exec/aggregation_source_operator.cpp b/be/src/pipeline/exec/aggregation_source_operator.cpp index fe03eba41029553..e43aebb5a333013 100644 --- a/be/src/pipeline/exec/aggregation_source_operator.cpp +++ b/be/src/pipeline/exec/aggregation_source_operator.cpp @@ -392,6 +392,18 @@ Status AggLocalState::_get_without_key_result(RuntimeState* state, vectorized::B for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { data_types[i] = shared_state.aggregate_evaluators[i]->function()->get_return_type(); columns[i] = data_types[i]->create_column(); + +#ifndef NDEBUG + if (shared_state.aggregate_evaluators[i]->function()->get_nullable_property() == + vectorized::NullablePropertyEnum::ALWAYS_NULLABLE) { + DCHECK(data_types[i]->is_nullable()) << fmt::format( + "Query {}, AlwaysNullable aggregate function {} should return ColumnNullable, " + "but get {}", + print_id(state->query_id()), + shared_state.aggregate_evaluators[i]->function()->get_name(), + data_types[i]->get_name()); + } +#endif } for (int i = 0; i < shared_state.aggregate_evaluators.size(); ++i) { diff --git a/be/src/pipeline/exec/analytic_source_operator.cpp b/be/src/pipeline/exec/analytic_source_operator.cpp index b9e48727656e056..edb3e8aae5d392c 100644 --- a/be/src/pipeline/exec/analytic_source_operator.cpp +++ 
b/be/src/pipeline/exec/analytic_source_operator.cpp @@ -499,11 +499,13 @@ Status AnalyticSourceOperatorX::init(const TPlanNode& tnode, RuntimeState* state RETURN_IF_ERROR(OperatorX::init(tnode, state)); const TAnalyticNode& analytic_node = tnode.analytic_node; size_t agg_size = analytic_node.analytic_functions.size(); - for (int i = 0; i < agg_size; ++i) { vectorized::AggFnEvaluator* evaluator = nullptr; + // Window function treats all NullableAggregateFunction as AlwaysNullable. + // It behaves the same as an aggregation executed without group by key. + // https://github.com/apache/doris/pull/40693 RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( - _pool, analytic_node.analytic_functions[i], {}, &evaluator)); + _pool, analytic_node.analytic_functions[i], {}, /*without_key*/ true, &evaluator)); _agg_functions.emplace_back(evaluator); } diff --git a/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp b/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp index 70b73225f060e82..a917305edbb9b31 100644 --- a/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp +++ b/be/src/pipeline/exec/distinct_streaming_aggregation_operator.cpp @@ -361,7 +361,7 @@ Status DistinctStreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( _pool, tnode.agg_node.aggregate_functions[i], tnode.agg_node.__isset.agg_sort_infos ? 
tnode.agg_node.agg_sort_infos[i] : dummy, - &evaluator)); + tnode.agg_node.grouping_exprs.empty(), &evaluator)); _aggregate_evaluators.push_back(evaluator); } diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.cpp b/be/src/pipeline/exec/streaming_aggregation_operator.cpp index dfbe42c637ea568..8afbbc4ed2927c4 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.cpp +++ b/be/src/pipeline/exec/streaming_aggregation_operator.cpp @@ -1156,7 +1156,7 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state) RETURN_IF_ERROR(vectorized::AggFnEvaluator::create( _pool, tnode.agg_node.aggregate_functions[i], tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy, - &evaluator)); + tnode.agg_node.grouping_exprs.empty(), &evaluator)); _aggregate_evaluators.push_back(evaluator); } diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index 05f1bd2a602c685..672345d750fb15c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -20,6 +20,11 @@ #pragma once +#include +#include + +#include "common/exception.h" +#include "common/status.h" #include "util/defer_op.h" #include "vec/columns/column_complex.h" #include "vec/columns/column_string.h" @@ -30,6 +35,7 @@ #include "vec/core/column_numbers.h" #include "vec/core/field.h" #include "vec/core/types.h" +#include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_string.h" namespace doris::vectorized { @@ -62,6 +68,126 @@ using ConstAggregateDataPtr = const char*; } \ } while (0) +enum struct NullablePropertyEnum : UInt8 { + ALWAYS_NOT_NULLABLE = 0, + ALWAYS_NULLABLE = 1, + PROPOGATE_NULLABLE = 2, + NORMAL_NULLABLE = 3, + SKIP_NULLABLE_PROPERTY_CHECK = 4, +}; + +struct AlwaysNullable { + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const 
DataTypePtr result_type_with_nullable, + std::string& msg) { + bool valid = result_type_with_nullable->is_nullable(); + if (!valid) { + msg = "AlwaysNullable property is not satisfied: result type must be Nullable"; + } + return valid; + } +}; + +struct AlwaysNotNullable { + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable, + std::string& msg) { + bool valid = !result_type_with_nullable->is_nullable(); + if (!valid) { + msg = "AlwaysNotNullable property is not satisfied: result type must be NotNullable"; + } + return valid; + } +}; + +// PropograteNullable is deprecated after this pr: https://github.com/apache/doris/pull/37330 +// There are no more PropograteNullable aggregate functions, use NullableAggregateFunction instead. +// We keep this struct because on branch 2.1.x many aggregate functions on FE are still PropograteNullable. +struct PropograteNullable { + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable, + std::string& msg) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "PropograteNullable should not used after version 2.1.x"); + } +}; + +// For some aggregate functions, we can skip the nullable property check. +// Maybe their nullable type is too complicated. 
+struct SkipNullablePropertyCheck { + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable, + std::string& msg) { + return true; + } +}; + +struct NullableAggregateFunction { + static bool is_valid_nullable_property(const bool without_key, + const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable, + std::string& msg) { + // Branch on the inputs only: a nullable argument always requires a nullable result. + if (std::any_of(argument_types_with_nullable.begin(), argument_types_with_nullable.end(), + [](const DataTypePtr& type) { return type->is_nullable(); })) { + bool valid = result_type_with_nullable->is_nullable(); + if (!valid) { + std::string arg_type = ""; + for (const auto& type : argument_types_with_nullable) { + arg_type += type->get_name() + ", "; + } + std::string result_type = result_type_with_nullable->get_name(); + msg = fmt::format( + "NullableAggregateFunction property is not satisfied, input: {}, result: " + "{}", + arg_type, result_type); + } + // One of input arguments is nullable, the result must be nullable. + return valid; + } else { + // No input column is nullable, the result can be nullable or not. + // Depends on whether executed with group by. + if (without_key) { + bool valid = result_type_with_nullable->is_nullable(); + if (!valid) { + std::string arg_type = ""; + for (const auto& type : argument_types_with_nullable) { + arg_type += type->get_name() + ", "; + } + std::string result_type = result_type_with_nullable->get_name(); + msg = fmt::format( + "NullableAggregateFunction property is not satisfied, input: {}, " + "result: " + "{}, with group by {}", + arg_type, result_type, !without_key); + } + // If without key, means agg is executed without group by, the result must be nullable. 
+ return valid; + } else { + bool valid = !result_type_with_nullable->is_nullable(); + if (!valid) { + std::string arg_type = ""; + for (const auto& type : argument_types_with_nullable) { + arg_type += type->get_name() + ", "; + } + std::string result_type = result_type_with_nullable->get_name(); + msg = fmt::format( + "NullableAggregateFunction property is not satisfied, input: {}, " + "result: " + "{}, with group by {}", + arg_type, result_type, !without_key); + } + // If not without key, means agg is executed with group by, the result must be not nullable. + return valid; + } + } + } +}; + /** Aggregate functions interface. * Instances of classes with this interface do not contain the data itself for aggregation, * but contain only metadata (description) of the aggregate function, @@ -219,6 +345,10 @@ class IAggregateFunction { virtual AggregateFunctionPtr transmit_to_stable() { return nullptr; } + /// Verify function signature + virtual Status verify_result_type(const bool without_key, const DataTypes& argument_types, + const DataTypePtr result_type) const = 0; + protected: DataTypes argument_types; int version {}; @@ -491,6 +621,41 @@ class IAggregateFunctionHelper : public IAggregateFunction { arena); assert_cast(this)->merge(place, rhs, arena); } + + Status verify_result_type(const bool without_key, const DataTypes& argument_types_with_nullable, + const DataTypePtr result_type_with_nullable) const override { + DataTypePtr function_result_type = assert_cast(this)->get_return_type(); + + if (function_result_type->equals(*result_type_with_nullable)) { + return Status::OK(); + } + + if (!remove_nullable(function_result_type) + ->equals(*remove_nullable(result_type_with_nullable))) { + return Status::InternalError( + "Result type is not matched, planner expect {}, but get {}, wihout group by: " + "{}", + result_type_with_nullable->get_name(), function_result_type->get_name(), + without_key); + } + + if (without_key == true) { + if 
(result_type_with_nullable->is_nullable()) { + // This branch is dedicated to NullableAggregateFunction. + // When they are executed without group by key, the result from planner will be AlwaysNullable + // since Planner does not know whether there is any invalid input at runtime, if so, the result + // should be Null, so the result type must be nullable. + // Backend will wrap a ColumnNullable in this situation. For example: AggLocalState::_get_without_key_result + return Status::OK(); + } + } + + // Executed with group by key, result type must be exactly the same as the return type from Planner. + return Status::InternalError( + "Result type is not matched, planner expect {}, but get {}, wihout group by: {}", + result_type_with_nullable->get_name(), function_result_type->get_name(), + without_key); + } }; /// Implements several methods for manipulation with data. T - type of structure with data for aggregation. diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index c96d84db16c89c7..0ca356c1590131f 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -63,9 +63,10 @@ AggregateFunctionPtr get_agg_state_function(const DataTypes& argument_types, argument_types, return_type); } -AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) +AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, const bool without_key) : _fn(desc.fn), _is_merge(desc.agg_expr.is_merge_agg), + _without_key(without_key), _return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)) { bool nullable = true; if (desc.__isset.is_nullable) { @@ -83,8 +84,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) } Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info, - AggFnEvaluator** result) { - *result = pool->add(AggFnEvaluator::create_unique(desc.nodes[0]).release()); + const bool without_key, AggFnEvaluator** result) { + *result = 
pool->add(AggFnEvaluator::create_unique(desc.nodes[0], without_key).release()); auto& agg_fn_evaluator = *result; int node_idx = 0; for (int i = 0; i < desc.nodes[0].num_children; ++i) { @@ -213,6 +214,9 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, _function = transform_to_sort_agg_function(_function, _argument_types_with_sort, _sort_description, state); } + + RETURN_IF_ERROR(_function->verify_result_type(_without_key, argument_types, _data_type)); + _expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name); return Status::OK(); } @@ -320,6 +324,7 @@ AggFnEvaluator* AggFnEvaluator::clone(RuntimeState* state, ObjectPool* pool) { AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state) : _fn(evaluator._fn), _is_merge(evaluator._is_merge), + _without_key(evaluator._without_key), _argument_types_with_sort(evaluator._argument_types_with_sort), _real_argument_types(evaluator._real_argument_types), _return_type(evaluator._return_type), diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h index 7dcd1b3e02bb474..30983795e42b72b 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.h +++ b/be/src/vec/exprs/vectorized_agg_fn.h @@ -50,7 +50,7 @@ class AggFnEvaluator { public: static Status create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info, - AggFnEvaluator** result); + const bool without_key, AggFnEvaluator** result); Status prepare(RuntimeState* state, const RowDescriptor& desc, const SlotDescriptor* intermediate_slot_desc, @@ -109,8 +109,12 @@ class AggFnEvaluator { const TFunction _fn; const bool _is_merge; + // We need this flag to distinguish between the two types of aggregation functions: + // 1. executed without group by key (agg function used with window function is also regarded as this type) + // 2. 
executed with group by key + const bool _without_key; - AggFnEvaluator(const TExprNode& desc); + AggFnEvaluator(const TExprNode& desc, const bool without_key); AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state); Status _calc_argument_columns(Block* block);