Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Adds boost root finders with reverse mode specializations #2720

Open
wants to merge 25 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c33ef74
Adds scalar root finding function
SteveBronder Apr 28, 2022
b015dae
cleanup
SteveBronder Apr 28, 2022
bd9fc69
cleanup and have the functors pass as a set of tuples
SteveBronder Apr 29, 2022
9b2b477
adds promotion logic to root_solver_tol to avoid boost errors
SteveBronder Apr 29, 2022
d729001
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 29, 2022
f2f8c51
fix header includes
SteveBronder Apr 29, 2022
04dbda8
fix header includes
SteveBronder Apr 29, 2022
c0d4a81
start working on beta test
SteveBronder May 3, 2022
b08f3ad
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 3, 2022
33e0b50
update rev test
SteveBronder May 10, 2022
f468582
Merge commit '2e45ac5788d650f1f2bf05c6bc3df9c3f0ab69b5' into HEAD
yashikno May 10, 2022
78f84c5
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 10, 2022
5bff38b
update test
SteveBronder May 11, 2022
f5c762b
Merge branch 'feature/root-finder' of github.com:stan-dev/math into f…
SteveBronder May 11, 2022
26ca3fc
add more rev test with explicit derivative
SteveBronder May 12, 2022
4c6956f
Merge commit '43ec11b55f3f6d35d8962ec37e879038d19321dc' into HEAD
yashikno May 12, 2022
91a748c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 12, 2022
ee38928
update root finder to pass the function F at compile time
SteveBronder May 23, 2022
344a450
fix sign
SteveBronder May 24, 2022
ef8d4e0
fix sign
SteveBronder May 24, 2022
4932068
Merge branch 'feature/root-finder' of github.com:stan-dev/math into f…
SteveBronder May 24, 2022
19588e4
turn on fifth root mix check
SteveBronder May 25, 2022
11c4a03
Merge remote-tracking branch 'origin/develop' into feature/root-finder
SteveBronder Jun 3, 2022
13a8886
fixup tests to pass
SteveBronder Jun 3, 2022
8cd506b
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jun 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <stan/math/fwd/fun/fmax.hpp>
#include <stan/math/fwd/fun/fmin.hpp>
#include <stan/math/fwd/fun/fmod.hpp>
#include <stan/math/fwd/fun/frexp.hpp>
#include <stan/math/fwd/fun/gamma_p.hpp>
#include <stan/math/fwd/fun/gamma_q.hpp>
#include <stan/math/fwd/fun/grad_inc_beta.hpp>
Expand Down Expand Up @@ -103,6 +104,7 @@
#include <stan/math/fwd/fun/read_fvar.hpp>
#include <stan/math/fwd/fun/rising_factorial.hpp>
#include <stan/math/fwd/fun/round.hpp>
#include <stan/math/fwd/fun/sign.hpp>
#include <stan/math/fwd/fun/sin.hpp>
#include <stan/math/fwd/fun/sinh.hpp>
#include <stan/math/fwd/fun/softmax.hpp>
Expand Down
1 change: 1 addition & 0 deletions stan/math/fwd/fun/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stan/math/prim/fun/abs.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <complex>

namespace stan {
Expand Down
17 changes: 17 additions & 0 deletions stan/math/fwd/fun/frexp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef STAN_MATH_FWD_FUN_FREXP_HPP
#define STAN_MATH_FWD_FUN_FREXP_HPP

#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>

namespace stan {
namespace math {

template <typename T>
inline auto frexp(const fvar<T>& x, int* exponent) noexcept {
return std::frexp(value_of_rec(x), exponent);
}
} // namespace math
} // namespace stan
#endif
19 changes: 19 additions & 0 deletions stan/math/fwd/fun/sign.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef STAN_MATH_FWD_FUN_SIGN_HPP
#define STAN_MATH_FWD_FUN_SIGN_HPP

#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>

namespace stan {
namespace math {

template <typename T>
inline auto sign(const fvar<T>& x) {
double z = value_of_rec(x);
return (z == 0) ? 0 : z < 0 ? -1 : 1;
}

} // namespace math
} // namespace stan
#endif
2 changes: 1 addition & 1 deletion stan/math/prim/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@
#include <stan/math/prim/functor/operands_and_partials.hpp>
#include <stan/math/prim/functor/reduce_sum.hpp>
#include <stan/math/prim/functor/reduce_sum_static.hpp>

#include <stan/math/prim/functor/root_finder.hpp>
#endif
195 changes: 195 additions & 0 deletions stan/math/prim/functor/root_finder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP
#define STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_bounded.hpp>
#include <stan/math/prim/err/check_positive.hpp>
#include <stan/math/prim/functor/apply.hpp>
#include <boost/math/tools/roots.hpp>
#include <tuple>
#include <utility>

namespace stan {
namespace math {
namespace internal {
template <bool ReturnDerivs, typename FRootFunc, typename... Args,
std::enable_if_t<ReturnDerivs>* = nullptr>
inline auto make_root_func(Args&&... args) {
return [&args...](auto&& x) {
return std::decay_t<FRootFunc>::template run<ReturnDerivs>(x, args...);
};
}

template <bool ReturnDerivs, typename FRootFunc,
std::enable_if_t<!ReturnDerivs>* = nullptr>
inline auto make_root_func() {
return [](auto&&... args) {
return std::decay_t<FRootFunc>::template run<ReturnDerivs>(args...);
};
}

struct NewtonRootSolver {
template <typename... Types>
static inline auto run(Types&&... args) {
return boost::math::tools::newton_raphson_iterate(
std::forward<Types>(args)...);
}
};

struct HalleyRootSolver {
template <typename... Types>
static inline auto run(Types&&... args) {
return boost::math::tools::halley_iterate(std::forward<Types>(args)...);
}
};

struct SchroderRootSolver {
template <typename... Types>
static inline auto run(Types&&... args) {
return boost::math::tools::schroder_iterate(std::forward<Types>(args)...);
}
};

} // namespace internal

/**
* Solve for root using Boost's Halley method
* @tparam FRootFunc A struct or class with a static function called `run`.
* The structs `run` function must have a boolean template parameter that
* when `true` returns a tuple containing the function result and the
* derivatives needed for the root finder. When the boolean template parameter
* is `false` the function should return a single value containing the function
* result.
* @tparam SolverFun One of the three struct types used to call the root solver.
* (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`).
* @tparam GuessScalar Scalar type
* @tparam MinScalar Scalar type
* @tparam MaxScalar Scalar type
* @tparam Types Arg types to pass to functors in `f_tuple`
* @param guess An initial guess at the root value
* @param min The minimum possible value for the result, this is used as an
* initial lower bracket
* @param max The maximum possible value for the result, this is used as an
* initial upper bracket
* @param digits The desired number of binary digits precision
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indicate that digits cannot exceed the precision of f_tuple.

* @param max_iter An optional maximum number of iterations to perform. On exit,
* this is updated to the actual number of iterations performed
* @param args Parameter pack of arguments to pass the the functors in `f_tuple`
*/
template <typename FRootFunc, typename SolverFun, typename GuessScalar,
typename MinScalar, typename MaxScalar, typename... Types,
require_all_not_st_var<GuessScalar, MinScalar, MaxScalar,
Types...>* = nullptr>
inline auto root_finder_tol(const GuessScalar guess, const MinScalar min,
const MaxScalar max, const int digits,
std::uintmax_t& max_iter, Types&&... args) {
check_bounded("root_finder", "initial guess", guess, min, max);
check_positive("root_finder", "digits", digits);
check_positive("root_finder", "max_iter", max_iter);
using ret_t = return_type_t<GuessScalar, MinScalar, MaxScalar, Types...>;
ret_t ret = 0;
auto f_plus_div
= internal::make_root_func<true, FRootFunc>(std::forward<Types>(args)...);
try {
ret = std::decay_t<SolverFun>::run(f_plus_div, ret_t(guess), ret_t(min),
ret_t(max), digits, max_iter);
} catch (const std::exception& e) {
throw e;
}
return ret;
}

template <typename FRootFunc, typename GuessScalar, typename MinScalar,
typename MaxScalar, typename... Types>
inline auto root_finder_halley_tol(const GuessScalar guess, const MinScalar min,
const MaxScalar max, const int digits,
std::uintmax_t& max_iter, Types&&... args) {
return root_finder_tol<FRootFunc, internal::HalleyRootSolver>(
guess, min, max, digits, max_iter, std::forward<Types>(args)...);
}

template <typename FRootFunc, typename GuessScalar, typename MinScalar,
typename MaxScalar, typename... Types>
inline auto root_finder_newton_raphson_tol(
const GuessScalar guess, const MinScalar min, const MaxScalar max,
const int digits, std::uintmax_t& max_iter, Types&&... args) {
return root_finder_tol<FRootFunc, internal::NewtonRootSolver>(
guess, min, max, digits, max_iter, std::forward<Types>(args)...);
}

template <typename FRootFunc, typename GuessScalar, typename MinScalar,
typename MaxScalar, typename... Types>
inline auto root_finder_schroder_tol(const GuessScalar guess,
const MinScalar min, const MaxScalar max,
const int digits, std::uintmax_t& max_iter,
Types&&... args) {
return root_finder_tol<FRootFunc, internal::SchroderRootSolver>(
guess, min, max, digits, max_iter, std::forward<Types>(args)...);
}

/**
* Solve for root with default values for the tolerances
* @tparam FRootFunc A struct or class with a static function called `run`.
* The structs `run` function must have a boolean template parameter that
* when `true` returns a tuple containing the function result and the
* derivatives needed for the root finder. When the boolean template parameter
* is `false` the function should return a single value containing the function
* result.
* @tparam SolverFun One of the three struct types used to call the root solver.
* (`NewtonRootSolver`, `HalleyRootSolver`, `SchroderRootSolver`).
* @tparam GuessScalar Scalar type
* @tparam MinScalar Scalar type
* @tparam MaxScalar Scalar type
* @tparam Types Arg types to pass to functors in `f_tuple`
* @param guess An initial guess at the root value
* @param min The minimum possible value for the result, this is used as an
* initial lower bracket
* @param max The maximum possible value for the result, this is used as an
* initial upper bracket
* @param args Parameter pack of arguments to pass the the functors in `f_tuple`
*/
template <typename FRootFunc, typename SolverFun, typename GuessScalar,
typename MinScalar, typename MaxScalar, typename... Types>
inline auto root_finder(const GuessScalar guess, const MinScalar min,
const MaxScalar max, Types&&... args) {
constexpr int digits = 16;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how was this default chosen? Maybe add one sentence about this choice in the doxygen doc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to update this to be in line with what boost's docs say

The value of digits is crucial to good performance of these functions, if it is set too high then at best you will get one extra (unnecessary) iteration, and at worst the last few steps will proceed by bisection. Remember that the returned value can never be more accurate than f(x) can be evaluated, and that if f(x) suffers from cancellation errors as it tends to zero then the computed steps will be effectively random. The value of digits should be set so that iteration terminates before this point: remember that for second and third order methods the number of correct digits in the result is increasing quite substantially with each iteration, digits should be set by experiment so that the final iteration just takes the next value into the zone where f(x) becomes inaccurate. A good starting point for digits would be 0.6D for Newton and 0.4D for Halley or Shröder iteration, where D is std::numeric_limits::digits.

https://www.boost.org/doc/libs/1_62_0/libs/math/doc/html/math_toolkit/roots/roots_deriv.html

std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max();
return root_finder_tol<FRootFunc, SolverFun>(
guess, min, max, digits, max_iter, std::forward<Types>(args)...);
}

template <typename FRootFunc, typename GuessScalar, typename MinScalar,
typename MaxScalar, typename... Types>
inline auto root_finder_hailey(const GuessScalar guess, const MinScalar min,
const MaxScalar max, Types&&... args) {
constexpr int digits = 16;
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max();
return root_finder_halley_tol<FRootFunc>(guess, min, max, digits, max_iter,
std::forward<Types>(args)...);
}

template <typename FRootFunc, typename GuessScalar, typename MinScalar,
typename MaxScalar, typename... Types>
inline auto root_finder_newton_raphson(const GuessScalar guess,
const MinScalar min, const MaxScalar max,
Types&&... args) {
constexpr int digits = 16;
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max();
return root_finder_newton_raphson_tol<FRootFunc>(
guess, min, max, digits, max_iter, std::forward<Types>(args)...);
}

template <typename FRootFunc, typename GuessScalar, typename MinScalar,
typename MaxScalar, typename... Types>
inline auto root_finder_schroder(const GuessScalar guess, const MinScalar min,
const MaxScalar max, Types&&... args) {
constexpr int digits = 16;
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max();
return root_finder_schroder_tol<FRootFunc>(guess, min, max, digits, max_iter,
std::forward<Types>(args)...);
}

} // namespace math
} // namespace stan
#endif
2 changes: 2 additions & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include <stan/math/rev/fun/fmax.hpp>
#include <stan/math/rev/fun/fmin.hpp>
#include <stan/math/rev/fun/fmod.hpp>
#include <stan/math/rev/fun/frexp.hpp>
#include <stan/math/rev/fun/from_var_value.hpp>
#include <stan/math/rev/fun/gamma_p.hpp>
#include <stan/math/rev/fun/gamma_q.hpp>
Expand Down Expand Up @@ -161,6 +162,7 @@
#include <stan/math/rev/fun/singular_values.hpp>
#include <stan/math/rev/fun/svd_U.hpp>
#include <stan/math/rev/fun/svd_V.hpp>
#include <stan/math/rev/fun/sign.hpp>
#include <stan/math/rev/fun/sinh.hpp>
#include <stan/math/rev/fun/softmax.hpp>
#include <stan/math/rev/fun/sqrt.hpp>
Expand Down
14 changes: 14 additions & 0 deletions stan/math/rev/fun/frexp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef STAN_MATH_REV_FUN_FREXP_HPP
#define STAN_MATH_REV_FUN_FREXP_HPP

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>

namespace stan {
namespace math {
inline auto frexp(stan::math::var x, int* exponent) noexcept {
return std::frexp(x.val(), exponent);
}
} // namespace math
} // namespace stan
#endif
12 changes: 12 additions & 0 deletions stan/math/rev/fun/sign.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef STAN_MATH_REV_FUN_SIGN_HPP
#define STAN_MATH_REV_FUN_SIGN_HPP

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>

namespace stan {
namespace math {
inline int sign(stan::math::var z) { return (z == 0) ? 0 : z < 0 ? -1 : 1; }
} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/rev/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <stan/math/rev/functor/map_rect_reduce.hpp>
#include <stan/math/rev/functor/operands_and_partials.hpp>
#include <stan/math/rev/functor/reduce_sum.hpp>
#include <stan/math/rev/functor/root_finder.hpp>
#include <stan/math/rev/functor/finite_diff_hessian_auto.hpp>

#endif
Loading