Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Incomplete Beta Function Inverse #2637

Merged
merged 17 commits into from
Mar 26, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <stan/math/fwd/fun/grad_inc_beta.hpp>
#include <stan/math/fwd/fun/hypot.hpp>
#include <stan/math/fwd/fun/inc_beta.hpp>
#include <stan/math/fwd/fun/inc_beta_inv.hpp>
#include <stan/math/fwd/fun/inv.hpp>
#include <stan/math/fwd/fun/inv_Phi.hpp>
#include <stan/math/fwd/fun/inv_cloglog.hpp>
Expand Down
87 changes: 87 additions & 0 deletions stan/math/fwd/fun/inc_beta_inv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#ifndef STAN_MATH_FWD_FUN_INC_BETA_INV_HPP
#define STAN_MATH_FWD_FUN_INC_BETA_INV_HPP

#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/inc_beta_inv.hpp>
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/F32.hpp>

namespace stan {
namespace math {

/**
* The inverse of the normalized incomplete beta function of a, b, with
* probability p.
*
* Used to compute the cumulative density function for the beta
* distribution.
*
* @param a Shape parameter a >= 0; a and b can't both be 0
* @param b Shape parameter b >= 0
* @param p Random variate. 0 <= p <= 1
* @throws if constraints are violated or if any argument is NaN
* @return The inverse of the normalized incomplete beta function.
*/
template <typename T1, typename T2, typename T3,
require_all_stan_scalar_t<T1, T2, T3>* = nullptr,
require_any_fvar_t<T1, T2, T3>* = nullptr>
inline fvar<partials_return_t<T1, T2, T3>> inc_beta_inv(const T1& a,
const T2& b,
const T3& p) {
using T_return = partials_return_t<T1, T2, T3>;
auto a_val = value_of(a);
auto b_val = value_of(b);
auto p_val = value_of(p);
T_return w = inc_beta_inv(a_val, b_val, p_val);
T_return log_w = log(w);
T_return log1m_w = log1m(w);
auto one_m_a = 1 - a_val;
auto one_m_b = 1 - b_val;
T_return one_m_w = 1 - w;
auto ap1 = a_val + 1;
auto bp1 = b_val + 1;
auto lbeta_ab = lbeta(a_val, b_val);
auto digamma_apb = digamma(a_val + b_val);

T_return inv_d_(0);

if (is_fvar<T1>::value) {
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
auto da2
= exp(a_val * log_w + 2 * lgamma(a_val)
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) - 2 * lgamma(ap1));
auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
* (log_w - digamma(a_val) + digamma_apb);
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3);
}

if (is_fvar<T2>::value) {
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
auto db2 = 2 * lgamma(b_val)
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w))
- 2 * lgamma(bp1) + b_val * log1m_w;

auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
* (log1m_w - digamma(b_val) + digamma_apb);

inv_d_ += forward_as<fvar<T_return>>(b).d_ * db1 * (exp(db2) - db3);
}

if (is_fvar<T3>::value) {
inv_d_ += forward_as<fvar<T_return>>(p).d_
* exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab);
}

return fvar<T_return>(w, inv_d_);
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
#include <stan/math/prim/fun/if_else.hpp>
#include <stan/math/prim/fun/imag.hpp>
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/inc_beta_inv.hpp>
#include <stan/math/prim/fun/initialize.hpp>
#include <stan/math/prim/fun/initialize_fill.hpp>
#include <stan/math/prim/fun/int_step.hpp>
Expand Down
22 changes: 14 additions & 8 deletions stan/math/prim/fun/F32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,29 +49,35 @@ namespace math {
* @param[in] precision precision of the infinite sum. defaults to 1e-6
* @param[in] max_steps number of steps to take. defaults to 1e5
*/
template <typename T>
T F32(const T& a1, const T& a2, const T& a3, const T& b1, const T& b2,
const T& z, double precision = 1e-6, int max_steps = 1e5) {
template <typename Ta1, typename Ta2, typename Ta3, typename Tb1, typename Tb2,
typename Tz>
return_type_t<Ta1, Ta2, Ta3, Tb1, Tb2, Tz> F32(const Ta1& a1, const Ta2& a2,
const Ta3& a3, const Tb1& b1,
const Tb2& b2, const Tz& z,
double precision = 1e-6,
int max_steps = 1e5) {
check_3F2_converges("F32", a1, a2, a3, b1, b2, z);

using T_return = return_type_t<Ta1, Ta2, Ta3, Tb1, Tb2, Tz>;
using std::exp;
using std::fabs;
using std::log;

T t_acc = 1.0;
T log_t = 0.0;
T log_z = log(z);
T_return t_acc = 1.0;
T_return log_t = 0.0;
Tz log_z = log(z);
double t_sign = 1.0;

for (int k = 0; k <= max_steps; ++k) {
T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (k + 1));
T_return p
= (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (k + 1));
if (p == 0.0) {
return t_acc;
}

log_t += log(fabs(p)) + log_z;
t_sign = p >= 0.0 ? t_sign : -t_sign;
T t_new = t_sign > 0.0 ? exp(log_t) : -exp(log_t);
T_return t_new = t_sign > 0.0 ? exp(log_t) : -exp(log_t);
t_acc += t_new;

if (fabs(t_new) <= precision) {
Expand Down
34 changes: 34 additions & 0 deletions stan/math/prim/fun/inc_beta_inv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef STAN_MATH_PRIM_FUN_INC_BETA_INV_HPP
#define STAN_MATH_PRIM_FUN_INC_BETA_INV_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/boost_policy.hpp>
#include <boost/math/special_functions/beta.hpp>

namespace stan {
namespace math {

/**
* The inverse of the normalized incomplete beta function of a, b, with
* probability p.
*
* Used to compute the cumulative density function for the beta
* distribution.
*
* @param a Shape parameter a >= 0; a and b can't both be 0
* @param b Shape parameter b >= 0
* @param p Random variate. 0 <= p <= 1
* @throws if constraints are violated or if any argument is NaN
* @return The inverse of the normalized incomplete beta function.
*/
inline double inc_beta_inv(double a, double b, double p) {
check_not_nan("inc_beta", "a", a);
check_not_nan("inc_beta", "b", b);
check_not_nan("inc_beta", "p", p);
return boost::math::ibeta_inv(a, b, p, boost_policy_t<>());
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
#include <stan/math/rev/fun/identity_free.hpp>
#include <stan/math/rev/fun/if_else.hpp>
#include <stan/math/rev/fun/inc_beta.hpp>
#include <stan/math/rev/fun/inc_beta_inv.hpp>
#include <stan/math/rev/fun/initialize_fill.hpp>
#include <stan/math/rev/fun/initialize_variable.hpp>
#include <stan/math/rev/fun/inv.hpp>
Expand Down
85 changes: 85 additions & 0 deletions stan/math/rev/fun/inc_beta_inv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#ifndef STAN_MATH_REV_FUN_INC_BETA_INV_HPP
#define STAN_MATH_REV_FUN_INC_BETA_INV_HPP

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/inc_beta_inv.hpp>
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/is_any_nan.hpp>

namespace stan {
namespace math {

/**
* The inverse of the normalized incomplete beta function of a, b, with
* probability p.
*
* Used to compute the cumulative density function for the beta
andrjohns marked this conversation as resolved.
Show resolved Hide resolved
* distribution.
*
* @param a Shape parameter a >= 0; a and b can't both be 0
* @param b Shape parameter b >= 0
* @param p Random variate. 0 <= p <= 1
* @throws if constraints are violated or if any argument is NaN
* @return The inverse of the normalized incomplete beta function.
*/
template <typename T1, typename T2, typename T3,
require_all_stan_scalar_t<T1, T2, T3>* = nullptr,
require_any_var_t<T1, T2, T3>* = nullptr>
inline var inc_beta_inv(const T1& a, const T2& b, const T3& p) {
double a_val = value_of(a);
double b_val = value_of(b);
double p_val = value_of(p);
double w = inc_beta_inv(a_val, b_val, p_val);
return make_callback_var(w, [a, b, p, a_val, b_val, p_val, w](auto& vi) {
double log_w = log(w);
double log1m_w = log1m(w);
double one_m_a = 1 - a_val;
double one_m_b = 1 - b_val;
double one_m_w = 1 - w;
double ap1 = a_val + 1;
double bp1 = b_val + 1;
double lbeta_ab = lbeta(a_val, b_val);
double digamma_apb = digamma(a_val + b_val);

if (!is_constant_all<T1>::value) {
double da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
double da2 = a_val * log_w + 2 * lgamma(a_val)
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w))
- 2 * lgamma(ap1);
double da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
* (log_w - digamma(a_val) + digamma_apb);

forward_as<var>(a).adj() += vi.adj() * da1 * (exp(da2) - da3);
}

if (!is_constant_all<T2>::value) {
double db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
double db2 = 2 * lgamma(b_val)
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w))
- 2 * lgamma(bp1) + b_val * log1m_w;

double db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
* (log1m_w - digamma(b_val) + digamma_apb);

forward_as<var>(b).adj() += vi.adj() * db1 * (exp(db2) - db3);
}

if (!is_constant_all<T3>::value) {
forward_as<var>(p).adj()
+= vi.adj() * exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab);
}
});
}

} // namespace math
} // namespace stan
#endif
33 changes: 33 additions & 0 deletions test/unit/math/fwd/fun/inc_beta_inv_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <stan/math/fwd.hpp>
#include <gtest/gtest.h>

TEST(AgradFwdMatrixIncBetaInv, fd_scalar) {
using stan::math::fvar;
using stan::math::inc_beta_inv;
fvar<double> a = 6;
fvar<double> b = 2;
fvar<double> p = 0.9;
a.d_ = 1.0;
b.d_ = 1.0;
p.d_ = 1.0;

fvar<double> res = inc_beta_inv(a, b, p);

EXPECT_FLOAT_EQ(res.d_, 0.0117172527399 - 0.0680999818473 + 0.455387298585);
}

TEST(AgradFwdMatrixIncBetaInv, ffd_scalar) {
using stan::math::fvar;
using stan::math::inc_beta_inv;
fvar<fvar<double>> a = 7;
fvar<fvar<double>> b = 4;
fvar<fvar<double>> p = 0.15;
a.val_.d_ = 1.0;
b.val_.d_ = 1.0;
p.val_.d_ = 1.0;

fvar<fvar<double>> res = inc_beta_inv(a, b, p);

EXPECT_FLOAT_EQ(res.val_.d_,
0.0428905418857 - 0.0563420377808 + 0.664919819507);
}
79 changes: 79 additions & 0 deletions test/unit/math/mix/fun/inc_beta_inv_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include <stan/math/mix.hpp>
#include <gtest/gtest.h>
#include <test/unit/math/rev/fun/util.hpp>

TEST(ProbInternalMath, inc_beta_inv_fv1) {
using stan::math::fvar;
using stan::math::inc_beta_inv;
using stan::math::var;
double a_d = 1;
double b_d = 2;
double p_d = 0.5;
fvar<var> a_v = a_d;
fvar<var> b_v = b_d;
fvar<var> p_v = p_d;
a_v.d_ = 1.0;
b_v.d_ = 1.0;
p_v.d_ = 1.0;

fvar<var> res = inc_beta_inv(a_v, b_v, p_v);
res.val_.grad();

EXPECT_FLOAT_EQ(a_v.val_.adj(), 0.287698278597);
EXPECT_FLOAT_EQ(b_v.val_.adj(), -0.122532267934);
EXPECT_FLOAT_EQ(p_v.val_.adj(), 0.707106781187);

a_v = a_d;
b_v = b_d;
p_v = p_d;
a_v.d_ = 1.0;
b_v.d_ = 1.0;
p_v.d_ = 1.0;

res = inc_beta_inv(a_d, b_v, p_v);
res.val_.grad();

EXPECT_FLOAT_EQ(b_v.val_.adj(), -0.122532267934);
EXPECT_FLOAT_EQ(p_v.val_.adj(), 0.707106781187);

b_v = b_d;
p_v = p_d;
b_v.d_ = 1.0;
p_v.d_ = 1.0;

res = inc_beta_inv(a_v, b_d, p_v);
res.val_.grad();

EXPECT_FLOAT_EQ(a_v.val_.adj(), 0.287698278597);
EXPECT_FLOAT_EQ(p_v.val_.adj(), 0.707106781187);

a_v = a_d;
p_v = p_d;
a_v.d_ = 1.0;
p_v.d_ = 1.0;

res = inc_beta_inv(a_v, b_v, p_d);
res.val_.grad();

EXPECT_FLOAT_EQ(a_v.val_.adj(), 0.287698278597);
EXPECT_FLOAT_EQ(b_v.val_.adj(), -0.122532267934);
}

TEST(ProbInternalMath, inc_beta_inv_fv2) {
using stan::math::fvar;
using stan::math::inc_beta_inv;
using stan::math::var;
fvar<fvar<var>> a = 2;
fvar<fvar<var>> b = 5;
fvar<fvar<var>> p = 0.1;
a.d_ = 1.0;
b.d_ = 1.0;
p.d_ = 1.0;

fvar<fvar<var>> res = inc_beta_inv(a, b, p);
res.val_.val_.grad();

EXPECT_FLOAT_EQ(a.val_.val_.adj(), 0.0783025374798);
EXPECT_FLOAT_EQ(b.val_.val_.adj(), -0.0161882044585);
EXPECT_FLOAT_EQ(p.val_.val_.adj(), 0.530989359806);
}
Loading