Skip to content

Commit

Permalink
using sub class in regr_slope and regr_intercept
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoruet committed Sep 30, 2024
1 parent 29f6e52 commit 43c58df
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 29 deletions.
15 changes: 8 additions & 7 deletions be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@

namespace doris::vectorized {

template <typename T, template<typename> class StatFunctionTemplate>
AggregateFunctionPtr type_dispatch_for_aggregate_function_regr(
const DataTypes& argument_types, const bool& result_is_nullable, bool y_nullable_input,
bool x_nullable_input) {
template <typename T, template <typename> class StatFunctionTemplate>
AggregateFunctionPtr type_dispatch_for_aggregate_function_regr(const DataTypes& argument_types,
const bool& result_is_nullable,
bool y_nullable_input,
bool x_nullable_input) {
if (y_nullable_input) {
if (x_nullable_input) {
return creator_without_type::create_ignore_nullable<
Expand All @@ -54,7 +55,7 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_regr(
}
}

template <template<typename> class StatFunctionTemplate>
template <template <typename> class StatFunctionTemplate>
AggregateFunctionPtr create_aggregate_function_regr(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
Expand All @@ -72,8 +73,8 @@ AggregateFunctionPtr create_aggregate_function_regr(const std::string& name,
WhichDataType y_type(remove_nullable(argument_types[0]));
WhichDataType x_type(remove_nullable(argument_types[1]));

#define DISPATCH(TYPE) \
if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \
#define DISPATCH(TYPE) \
if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \
return type_dispatch_for_aggregate_function_regr<TYPE, StatFunctionTemplate>( \
argument_types, result_is_nullable, y_nullable_input, x_nullable_input);
FOR_NUMERIC_TYPES(DISPATCH)
Expand Down
40 changes: 18 additions & 22 deletions be/src/vec/aggregate_functions/aggregate_function_regr_union.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ namespace doris::vectorized {
template <typename T>
struct AggregateFunctionRegrData {
UInt64 count = 0;
Float64 sum_x{};
Float64 sum_y{};
Float64 sum_of_x_mul_y{};
Float64 sum_of_x_squared{};
Float64 sum_x {};
Float64 sum_y {};
Float64 sum_of_x_mul_y {};
Float64 sum_of_x_squared {};

void write(BufferWritable& buf) const {
write_binary(sum_x, buf);
Expand Down Expand Up @@ -90,44 +90,40 @@ struct AggregateFunctionRegrData {
};

template <typename T>
struct RegrSlopeFunc {
struct RegrSlopeFunc : AggregateFunctionRegrData<T> {
using Type = T;
using Data = AggregateFunctionRegrData<Type>;
static constexpr const char* name = "regr_slope";

template <typename Data>
static Float64 get_result(const Data& data) {
Float64 denominator = data.count * data.sum_of_x_squared - data.sum_x * data.sum_x;
if (data.count < 2 || denominator == 0.0) {
Float64 get_result() const {
Float64 denominator = this->count * this->sum_of_x_squared - this->sum_x * this->sum_x;
if (this->count < 2 || denominator == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
Float64 slope = (data.count * data.sum_of_x_mul_y - data.sum_x * data.sum_y) / denominator;
Float64 slope = (this->count * this->sum_of_x_mul_y - this->sum_x * this->sum_y) / denominator;
return slope;
}
};

template <typename T>
struct RegrInterceptFunc {
struct RegrInterceptFunc : AggregateFunctionRegrData<T> {
using Type = T;
using Data = AggregateFunctionRegrData<Type>;
static constexpr const char* name = "regr_intercept";

template <typename Data>
static Float64 get_result(const Data& data) {
Float64 denominator = data.count * data.sum_of_x_squared - data.sum_x * data.sum_x;
if (data.count < 2 || denominator == 0.0) {
Float64 get_result() const {
Float64 denominator = this->count * this->sum_of_x_squared - this->sum_x * this->sum_x;
if (this->count < 2 || denominator == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
Float64 slope = (data.count * data.sum_of_x_mul_y - data.sum_x * data.sum_y) / denominator;
Float64 intercept = (data.sum_y - slope * data.sum_x) / data.count;
Float64 slope = (this->count * this->sum_of_x_mul_y - this->sum_x * this->sum_y) / denominator;
Float64 intercept = (this->sum_y - slope * this->sum_x) / this->count;
return intercept;
}
};

template <typename RegrFunc, bool y_nullable, bool x_nullable>
class AggregateFunctionRegrSimple
: public IAggregateFunctionDataHelper<
typename RegrFunc::Data,
RegrFunc,
AggregateFunctionRegrSimple<RegrFunc, y_nullable, x_nullable>> {
public:
using Type = typename RegrFunc::Type;
Expand All @@ -137,7 +133,7 @@ class AggregateFunctionRegrSimple

explicit AggregateFunctionRegrSimple(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<
typename RegrFunc::Data,
RegrFunc,
AggregateFunctionRegrSimple<RegrFunc, y_nullable, x_nullable>>(
argument_types_) {
DCHECK(!argument_types_.empty());
Expand Down Expand Up @@ -208,7 +204,7 @@ class AggregateFunctionRegrSimple
const auto& data = this->data(place);
auto& dst_column_with_nullable = assert_cast<ColumnNullable&>(to);
auto& dst_column = assert_cast<ResultCol&>(dst_column_with_nullable.get_nested_column());
Float64 result = RegrFunc::get_result(data);
Float64 result = data.get_result();
if (std::isnan(result)) {
dst_column_with_nullable.get_null_map_data().push_back(1);
dst_column.insert_default();
Expand Down

0 comments on commit 43c58df

Please sign in to comment.