From 29f6e524f7ae45bab8c361b22da979dd253dc4ec Mon Sep 17 00:00:00 2001 From: yoruet <1559650411@qq.com> Date: Mon, 30 Sep 2024 15:40:08 +0800 Subject: [PATCH] union regr_slope and regr_intercept to regr_union --- .../aggregate_function_regr_intercept.cpp | 89 -------- .../aggregate_function_regr_intercept.h | 201 ------------------ ....cpp => aggregate_function_regr_union.cpp} | 44 ++-- ...lope.h => aggregate_function_regr_union.h} | 85 +++++--- .../aggregate_function_simple_factory.cpp | 6 +- 5 files changed, 78 insertions(+), 347 deletions(-) delete mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp delete mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h rename be/src/vec/aggregate_functions/{aggregate_function_regr_slope.cpp => aggregate_function_regr_union.cpp} (65%) rename be/src/vec/aggregate_functions/{aggregate_function_regr_slope.h => aggregate_function_regr_union.h} (74%) diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp deleted file mode 100644 index 3f2b5da0a7de74..00000000000000 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp +++ /dev/null @@ -1,89 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -#include "vec/aggregate_functions/aggregate_function_regr_intercept.h" - -#include "common/status.h" -#include "vec/aggregate_functions/aggregate_function.h" -#include "vec/aggregate_functions/aggregate_function_simple_factory.h" -#include "vec/aggregate_functions/helpers.h" -#include "vec/core/types.h" -#include "vec/data_types/data_type.h" -#include "vec/data_types/data_type_nullable.h" - -namespace doris::vectorized { - -template -AggregateFunctionPtr type_dispatch_for_aggregate_function_regr_intercept( - const DataTypes& argument_types, const bool& result_is_nullable, bool y_column_nullable, - bool x_column_nullable) { - using StatFunctionTemplate = RegrInterceptFuncTwoArg; - if (y_column_nullable) { - if (x_column_nullable) { - return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrInterceptSimple>( - argument_types, result_is_nullable); - } else { - return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrInterceptSimple>( - argument_types, result_is_nullable); - } - } else { - if (x_column_nullable) { - return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrInterceptSimple>( - argument_types, result_is_nullable); - } else { - return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrInterceptSimple>( - argument_types, result_is_nullable); - } - } -} - -AggregateFunctionPtr create_aggregate_function_regr_intercept(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { - if (argument_types.size() != 2) { - LOG(WARNING) << "aggregate function " << name << " requires exactly 2 arguments"; - return nullptr; - } - if (!result_is_nullable) { - LOG(WARNING) << "aggregate function " << name << " requires nullable result type"; - return nullptr; - } - - bool y_nullable_input = argument_types[0]->is_nullable(); - bool x_nullable_input = argument_types[1]->is_nullable(); - WhichDataType y_type(remove_nullable(argument_types[0])); - WhichDataType x_type(remove_nullable(argument_types[1])); - -#define DISPATCH(TYPE) \ - if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \ - return type_dispatch_for_aggregate_function_regr_intercept( \ - argument_types, result_is_nullable, y_nullable_input, x_nullable_input); - FOR_NUMERIC_TYPES(DISPATCH) -#undef DISPATCH - - LOG(WARNING) << "Unsupported input types " << argument_types[0]->get_name() << " and " - << argument_types[1]->get_name() << " for aggregate function " << name; - return nullptr; -} - -void register_aggregate_function_regr_intercept(AggregateFunctionSimpleFactory& factory) { - factory.register_function_both("regr_intercept", create_aggregate_function_regr_intercept); -} -} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h deleted file mode 100644 index c72e8ca03e4247..00000000000000 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h +++ /dev/null @@ -1,201 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include -#include - -#include "common/exception.h" -#include "common/status.h" -#include "vec/aggregate_functions/aggregate_function.h" -#include "vec/columns/column_nullable.h" -#include "vec/columns/column_vector.h" -#include "vec/common/assert_cast.h" -#include "vec/core/field.h" -#include "vec/core/types.h" -#include "vec/data_types/data_type.h" -#include "vec/data_types/data_type_nullable.h" -#include "vec/data_types/data_type_number.h" -#include "vec/io/io_helper.h" -namespace doris::vectorized { - -template -struct AggregateFunctionRegrInterceptData { - UInt64 count = 0; - Float64 sum_x {}; - Float64 sum_y {}; - Float64 sum_of_x_mul_y {}; - Float64 sum_of_x_squared {}; - - void write(BufferWritable& buf) const { - write_binary(sum_x, buf); - write_binary(sum_y, buf); - write_binary(sum_of_x_mul_y, buf); - write_binary(sum_of_x_squared, buf); - write_binary(count, buf); - } - - void read(BufferReadable& buf) { - read_binary(sum_x, buf); - read_binary(sum_y, buf); - read_binary(sum_of_x_mul_y, buf); - read_binary(sum_of_x_squared, buf); - read_binary(count, buf); - } - - void reset() { - sum_x = {}; - sum_y = {}; - sum_of_x_mul_y = {}; - sum_of_x_squared = {}; - count = 0; - } - - Float64 get_intercept_result() const { - Float64 denominator = count * sum_of_x_squared - sum_x * sum_x; - if (count < 2 || denominator == 0.0) { - return std::numeric_limits::quiet_NaN(); - } - Float64 slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; - return (sum_y - slope * sum_x) / count; - } - - void merge(const AggregateFunctionRegrInterceptData& rhs) { - if (rhs.count == 0) { - return; - } - sum_x += rhs.sum_x; - sum_y += rhs.sum_y; - sum_of_x_mul_y += rhs.sum_of_x_mul_y; - sum_of_x_squared += rhs.sum_of_x_squared; - count += rhs.count; - } - - void add(T value_y, T value_x) { - sum_x += value_x; - sum_y += value_y; - sum_of_x_mul_y += value_x * value_y; - sum_of_x_squared += value_x * value_x; - count += 1; - } -}; - -template -struct RegrInterceptFuncTwoArg { - using Type = T; - using Data = AggregateFunctionRegrInterceptData; -}; - -template -class AggregateFunctionRegrInterceptSimple - : public IAggregateFunctionDataHelper< - typename StatFunc::Data, - AggregateFunctionRegrInterceptSimple> { -public: - using Type = typename StatFunc::Type; - using XInputCol = ColumnVector; - using YInputCol = ColumnVector; - using ResultCol = ColumnVector; - - explicit AggregateFunctionRegrInterceptSimple(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper< - typename StatFunc::Data, - AggregateFunctionRegrInterceptSimple>( - argument_types_) { - DCHECK(!argument_types_.empty()); - } - - String get_name() const override { return "regr_intercept"; } - - DataTypePtr get_return_type() const override { - return make_nullable(std::make_shared()); - } - - void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, - Arena*) const override { - bool y_null = false; - bool x_null = false; - const YInputCol* y_nested_column = nullptr; - const XInputCol* x_nested_column = nullptr; - - if constexpr (y_nullable) { - const ColumnNullable& y_column_nullable = - assert_cast(*columns[0]); - y_null = y_column_nullable.is_null_at(row_num); - y_nested_column = assert_cast( - y_column_nullable.get_nested_column_ptr().get()); - } else { - y_nested_column = assert_cast( - (*columns[0]).get_ptr().get()); - } - - if constexpr (x_nullable) { - const ColumnNullable& x_column_nullable = - assert_cast(*columns[1]); - x_null = x_column_nullable.is_null_at(row_num); - x_nested_column = assert_cast( - x_column_nullable.get_nested_column_ptr().get()); - } else { - x_nested_column = assert_cast( - (*columns[1]).get_ptr().get()); - } - - if (x_null || y_null) { - return; - } - - Type y_value = y_nested_column->get_data()[row_num]; - Type x_value = x_nested_column->get_data()[row_num]; - - this->data(place).add(y_value, x_value); - } - - void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } - - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, - Arena*) const override { - this->data(place).merge(this->data(rhs)); - } - - void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { - this->data(place).write(buf); - } - - void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, - Arena*) const override { - this->data(place).read(buf); - } - - void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - const auto& data = this->data(place); - auto& dst_column_with_nullable = assert_cast(to); - auto& dst_column = assert_cast(dst_column_with_nullable.get_nested_column()); - Float64 intercept = data.get_intercept_result(); - if (std::isnan(intercept)) { - dst_column_with_nullable.get_null_map_data().push_back(1); - dst_column.insert_default(); - } else { - dst_column_with_nullable.get_null_map_data().push_back(0); - dst_column.get_data().push_back(intercept); - } - } -}; - -} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_slope.cpp b/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp similarity index 65% rename from be/src/vec/aggregate_functions/aggregate_function_regr_slope.cpp rename to be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp index d3a148a81b5d60..ee255bc6cf3cb6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_slope.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp @@ -1,20 +1,21 @@ // Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file +// or more contributor license agreements. See the NOTICE file // distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file +// regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at +// with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the +// KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -#include "vec/aggregate_functions/aggregate_function_regr_slope.h" + +#include "vec/aggregate_functions/aggregate_function_regr_union.h" #include "common/status.h" #include "vec/aggregate_functions/aggregate_function.h" @@ -26,37 +27,37 @@ namespace doris::vectorized { -template -AggregateFunctionPtr type_dispatch_for_aggregate_function_regr_slope( +template class StatFunctionTemplate> +AggregateFunctionPtr type_dispatch_for_aggregate_function_regr( const DataTypes& argument_types, const bool& result_is_nullable, bool y_nullable_input, bool x_nullable_input) { - using StatFunctionTemplate = RegrSlopeFuncTwoArg; if (y_nullable_input) { if (x_nullable_input) { return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrSlopeSimple>( + AggregateFunctionRegrSimple, true, true>>( argument_types, result_is_nullable); } else { return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrSlopeSimple>( + AggregateFunctionRegrSimple, true, false>>( argument_types, result_is_nullable); } } else { if (x_nullable_input) { return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrSlopeSimple>( + AggregateFunctionRegrSimple, false, true>>( argument_types, result_is_nullable); } else { return creator_without_type::create_ignore_nullable< - AggregateFunctionRegrSlopeSimple>( + AggregateFunctionRegrSimple, false, false>>( argument_types, result_is_nullable); } } } -AggregateFunctionPtr create_aggregate_function_regr_slope(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +template class StatFunctionTemplate> +AggregateFunctionPtr create_aggregate_function_regr(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { if (argument_types.size() != 2) { LOG(WARNING) << "aggregate function " << name << " requires exactly 2 arguments"; return nullptr; @@ -71,19 +72,20 @@ AggregateFunctionPtr create_aggregate_function_regr_slope(const std::string& nam WhichDataType y_type(remove_nullable(argument_types[0])); WhichDataType x_type(remove_nullable(argument_types[1])); -#define DISPATCH(TYPE) \ - if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \ - return type_dispatch_for_aggregate_function_regr_slope( \ +#define DISPATCH(TYPE) \ + if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \ + return type_dispatch_for_aggregate_function_regr( \ argument_types, result_is_nullable, y_nullable_input, x_nullable_input); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - LOG(WARNING) << "Unsupported input types " << argument_types[0]->get_name() << " and " + LOG(WARNING) << "unsupported input types " << argument_types[0]->get_name() << " and " << argument_types[1]->get_name() << " for aggregate function " << name; return nullptr; } -void register_aggregate_function_regr_slope(AggregateFunctionSimpleFactory& factory) { - factory.register_function_both("regr_slope", create_aggregate_function_regr_slope); +void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& factory) { + factory.register_function_both("regr_slope", create_aggregate_function_regr); + factory.register_function_both("regr_intercept", create_aggregate_function_regr); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_slope.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h similarity index 74% rename from be/src/vec/aggregate_functions/aggregate_function_regr_slope.h rename to be/src/vec/aggregate_functions/aggregate_function_regr_union.h index e0f0db00db3f75..0296b2f800a245 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_slope.h +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h @@ -34,15 +34,16 @@ #include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_number.h" #include "vec/io/io_helper.h" + namespace doris::vectorized { template -struct AggregateFunctionRegrSlopeData { +struct AggregateFunctionRegrData { UInt64 count = 0; - Float64 sum_x {}; - Float64 sum_y {}; - Float64 sum_of_x_mul_y {}; - Float64 sum_of_x_squared {}; + Float64 sum_x{}; + Float64 sum_y{}; + Float64 sum_of_x_mul_y{}; + Float64 sum_of_x_squared{}; void write(BufferWritable& buf) const { write_binary(sum_x, buf); @@ -68,16 +69,7 @@ struct AggregateFunctionRegrSlopeData { count = 0; } - Float64 get_slope_result() const { - Float64 denominator = count * sum_of_x_squared - sum_x * sum_x; - if (count < 2 || denominator == 0.0) { - return std::numeric_limits::quiet_NaN(); - } - Float64 slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; - return slope; - } - - void merge(const AggregateFunctionRegrSlopeData& rhs) { + void merge(const AggregateFunctionRegrData& rhs) { if (rhs.count == 0) { return; } @@ -98,31 +90,60 @@ struct AggregateFunctionRegrSlopeData { }; template -struct RegrSlopeFuncTwoArg { +struct RegrSlopeFunc { using Type = T; - using Data = AggregateFunctionRegrSlopeData; + using Data = AggregateFunctionRegrData; + static constexpr const char* name = "regr_slope"; + + template + static Float64 get_result(const Data& data) { + Float64 denominator = data.count * data.sum_of_x_squared - data.sum_x * data.sum_x; + if (data.count < 2 || denominator == 0.0) { + return std::numeric_limits::quiet_NaN(); + } + Float64 slope = (data.count * data.sum_of_x_mul_y - data.sum_x * data.sum_y) / denominator; + return slope; + } }; -template -class AggregateFunctionRegrSlopeSimple +template +struct RegrInterceptFunc { + using Type = T; + using Data = AggregateFunctionRegrData; + static constexpr const char* name = "regr_intercept"; + + template + static Float64 get_result(const Data& data) { + Float64 denominator = data.count * data.sum_of_x_squared - data.sum_x * data.sum_x; + if (data.count < 2 || denominator == 0.0) { + return std::numeric_limits::quiet_NaN(); + } + Float64 slope = (data.count * data.sum_of_x_mul_y - data.sum_x * data.sum_y) / denominator; + Float64 intercept = (data.sum_y - slope * data.sum_x) / data.count; + return intercept; + } +}; + +template +class AggregateFunctionRegrSimple : public IAggregateFunctionDataHelper< - typename StatFunc::Data, - AggregateFunctionRegrSlopeSimple> { + typename RegrFunc::Data, + AggregateFunctionRegrSimple> { public: - using Type = typename StatFunc::Type; + using Type = typename RegrFunc::Type; using XInputCol = ColumnVector; using YInputCol = ColumnVector; using ResultCol = ColumnVector; - explicit AggregateFunctionRegrSlopeSimple(const DataTypes& argument_types_) + explicit AggregateFunctionRegrSimple(const DataTypes& argument_types_) : IAggregateFunctionDataHelper< - typename StatFunc::Data, - AggregateFunctionRegrSlopeSimple>( + typename RegrFunc::Data, + AggregateFunctionRegrSimple>( argument_types_) { DCHECK(!argument_types_.empty()); } - String get_name() const override { return "regr_slope"; } + String get_name() const override { return RegrFunc::name; } DataTypePtr get_return_type() const override { return make_nullable(std::make_shared()); @@ -145,6 +166,7 @@ class AggregateFunctionRegrSlopeSimple y_nested_column = assert_cast( (*columns[0]).get_ptr().get()); } + if constexpr (x_nullable) { const ColumnNullable& x_column_nullable = assert_cast(*columns[1]); @@ -160,8 +182,8 @@ class AggregateFunctionRegrSlopeSimple return; } - Type x_value = x_nested_column->get_data()[row_num]; Type y_value = y_nested_column->get_data()[row_num]; + Type x_value = x_nested_column->get_data()[row_num]; this->data(place).add(y_value, x_value); } @@ -186,15 +208,14 @@ class AggregateFunctionRegrSlopeSimple const auto& data = this->data(place); auto& dst_column_with_nullable = assert_cast(to); auto& dst_column = assert_cast(dst_column_with_nullable.get_nested_column()); - Float64 slope = data.get_slope_result(); - if (std::isnan(slope)) { + Float64 result = RegrFunc::get_result(data); + if (std::isnan(result)) { dst_column_with_nullable.get_null_map_data().push_back(1); dst_column.insert_default(); } else { dst_column_with_nullable.get_null_map_data().push_back(0); - dst_column.get_data().push_back(slope); + dst_column.get_data().push_back(result); } } }; - -} // namespace doris::vectorized \ No newline at end of file +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 58b4657841e006..eb8639d5908986 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -56,8 +56,7 @@ void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& fact void register_aggregate_function_percentile_old(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_funnel_old(AggregateFunctionSimpleFactory& factory); -void register_aggregate_function_regr_intercept(AggregateFunctionSimpleFactory& factory); -void register_aggregate_function_regr_slope(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory); @@ -102,8 +101,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_percentile_approx(instance); register_aggregate_function_window_funnel(instance); register_aggregate_function_window_funnel_old(instance); - register_aggregate_function_regr_intercept(instance); - register_aggregate_function_regr_slope(instance); + register_aggregate_function_regr_union(instance); register_aggregate_function_retention(instance); register_aggregate_function_orthogonal_bitmap(instance); register_aggregate_function_collect_list(instance);