Skip to content

Commit

Permalink
union regr_slope and regr_intercept to regr_union
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoruet committed Sep 30, 2024
1 parent 13d4676 commit 29f6e52
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 347 deletions.

This file was deleted.

201 changes: 0 additions & 201 deletions be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/aggregate_functions/aggregate_function_regr_slope.h"

#include "vec/aggregate_functions/aggregate_function_regr_union.h"

#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
Expand All @@ -26,37 +27,37 @@

namespace doris::vectorized {

template <typename T>
AggregateFunctionPtr type_dispatch_for_aggregate_function_regr_slope(
template <typename T, template<typename> class StatFunctionTemplate>
AggregateFunctionPtr type_dispatch_for_aggregate_function_regr(
const DataTypes& argument_types, const bool& result_is_nullable, bool y_nullable_input,
bool x_nullable_input) {
using StatFunctionTemplate = RegrSlopeFuncTwoArg<T>;
if (y_nullable_input) {
if (x_nullable_input) {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSlopeSimple<StatFunctionTemplate, true, true>>(
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, true, true>>(
argument_types, result_is_nullable);
} else {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSlopeSimple<StatFunctionTemplate, true, false>>(
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, true, false>>(
argument_types, result_is_nullable);
}
} else {
if (x_nullable_input) {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSlopeSimple<StatFunctionTemplate, false, true>>(
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, false, true>>(
argument_types, result_is_nullable);
} else {
return creator_without_type::create_ignore_nullable<
AggregateFunctionRegrSlopeSimple<StatFunctionTemplate, false, false>>(
AggregateFunctionRegrSimple<StatFunctionTemplate<T>, false, false>>(
argument_types, result_is_nullable);
}
}
}

AggregateFunctionPtr create_aggregate_function_regr_slope(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
template <template<typename> class StatFunctionTemplate>
AggregateFunctionPtr create_aggregate_function_regr(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
if (argument_types.size() != 2) {
LOG(WARNING) << "aggregate function " << name << " requires exactly 2 arguments";
return nullptr;
Expand All @@ -71,19 +72,20 @@ AggregateFunctionPtr create_aggregate_function_regr_slope(const std::string& nam
WhichDataType y_type(remove_nullable(argument_types[0]));
WhichDataType x_type(remove_nullable(argument_types[1]));

#define DISPATCH(TYPE) \
if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \
return type_dispatch_for_aggregate_function_regr_slope<TYPE>( \
#define DISPATCH(TYPE) \
if (x_type.idx == TypeIndex::TYPE && y_type.idx == TypeIndex::TYPE) \
return type_dispatch_for_aggregate_function_regr<TYPE, StatFunctionTemplate>( \
argument_types, result_is_nullable, y_nullable_input, x_nullable_input);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

LOG(WARNING) << "Unsupported input types " << argument_types[0]->get_name() << " and "
LOG(WARNING) << "unsupported input types " << argument_types[0]->get_name() << " and "
<< argument_types[1]->get_name() << " for aggregate function " << name;
return nullptr;
}

void register_aggregate_function_regr_slope(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("regr_slope", create_aggregate_function_regr_slope);
void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("regr_slope", create_aggregate_function_regr<RegrSlopeFunc>);
factory.register_function_both("regr_intercept", create_aggregate_function_regr<RegrInterceptFunc>);
}
} // namespace doris::vectorized
Loading

0 comments on commit 29f6e52

Please sign in to comment.