Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add perfect forwarding and constexpr to reverse mode functions #3092

Open
wants to merge 33 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f164443
adds perfect forwarding and uses constexpr in functions
SteveBronder Jul 11, 2024
d70fb0f
fix return for eigenvector_sym
SteveBronder Jul 11, 2024
f08c711
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 12, 2024
ddea3c0
fix bad formatting for clang-format
SteveBronder Jul 15, 2024
d98f4f0
Merge commit '9052db82c3fafce144c355c757a9b7e44b884b66' into HEAD
yashikno Jul 15, 2024
3012408
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 15, 2024
af37cbc
clean up type traits to change \!is_constant with is_autodiffable
SteveBronder Jul 15, 2024
5152942
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 15, 2024
888e2cc
force c++17
SteveBronder Jul 15, 2024
9887597
Merge remote-tracking branch 'refs/remotes/origin/feature/pf-funcs-co…
SteveBronder Jul 15, 2024
d4ebc3d
remove unused type alias
SteveBronder Jul 15, 2024
d7350f2
add check_vari_on_stack for arena matrix
SteveBronder Jul 15, 2024
a54fb01
update type traits for fft, square_dist, and trace funcs
SteveBronder Jul 15, 2024
bf36144
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 15, 2024
5d989a2
adds test framework for cleaning memory after autodiff calls
SteveBronder Jul 16, 2024
fdb5d03
update forward for ref in quad_form_sym
SteveBronder Jul 16, 2024
be242f0
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 16, 2024
aea8afa
Merge remote-tracking branch 'origin/develop' into feature/pf-funcs-c…
SteveBronder Jul 17, 2024
d5bcc83
update wrt review comments
SteveBronder Jul 17, 2024
a3f3cd8
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 17, 2024
c3d8136
update columns dot product for complex types
SteveBronder Jul 18, 2024
ab679fe
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 18, 2024
ad456a1
update columns_dot_product
SteveBronder Jul 18, 2024
bc96704
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 18, 2024
d1fc936
update
SteveBronder Jul 19, 2024
5c8bd28
Merge remote-tracking branch 'refs/remotes/origin/feature/pf-funcs-co…
SteveBronder Jul 19, 2024
17ffd02
start working on the constrain functions
SteveBronder Jul 19, 2024
506fe50
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 19, 2024
c4c7605
fix move semantics in unit_vector_constrain
SteveBronder Jul 22, 2024
82fc4f6
update ordered constrain
SteveBronder Jul 23, 2024
c826d98
fix use of x after forwarding in ordered_constrain
SteveBronder Jul 25, 2024
d5de333
update
SteveBronder Jul 29, 2024
855546a
Merge remote-tracking branch 'origin/develop' into feature/pf-funcs-c…
SteveBronder Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions make/compiler_flags
Original file line number Diff line number Diff line change
Expand Up @@ -120,32 +120,8 @@ INC_GTEST ?= -I $(GTEST)/include -I $(GTEST)
CPPFLAGS_BOOST ?= -DBOOST_DISABLE_ASSERTS
CPPFLAGS_SUNDIALS ?= -DNO_FPRINTF_OUTPUT $(CPPFLAGS_OPTIM_SUNDIALS) $(CXXFLAGS_FLTO_SUNDIALS)
#CPPFLAGS_GTEST ?=
STAN_HAS_CXX17 ?= false
ifeq ($(CXX_TYPE), gcc)
GCC_GE_73 := $(shell [ $(CXX_MAJOR) -gt 7 -o \( $(CXX_MAJOR) -eq 7 -a $(CXX_MINOR) -ge 1 \) ] && echo true)
ifeq ($(GCC_GE_73),true)
STAN_HAS_CXX17 := true
endif
else ifeq ($(CXX_TYPE), clang)
CLANG_GE_5 := $(shell [ $(CXX_MAJOR) -gt 5 -o \( $(CXX_MAJOR) -eq 5 -a $(CXX_MINOR) -ge 0 \) ] && echo true)
ifeq ($(CLANG_GE_5),true)
STAN_HAS_CXX17 := true
endif
else ifeq ($(CXX_TYPE), mingw32-gcc)
MINGW_GE_50 := $(shell [ $(CXX_MAJOR) -gt 5 -o \( $(CXX_MAJOR) -eq 5 -a $(CXX_MINOR) -ge 0 \) ] && echo true)
ifeq ($(MINGW_GE_50),true)
STAN_HAS_CXX17 := true
endif
endif

ifeq ($(STAN_HAS_CXX17), true)
CXXFLAGS_LANG ?= -std=c++17
CXXFLAGS_STANDARD ?= c++17
else
$(warning "Stan cannot detect if your compiler has the C++17 standard. If it does, please set STAN_HAS_CXX17=true in your make/local file. C++17 support is mandatory in the next release of Stan. Defaulting to C++14")
CXXFLAGS_LANG ?= -std=c++1y
CXXFLAGS_STANDARD ?= c++1y
endif
CXXFLAGS_LANG ?= -std=c++17
CXXFLAGS_STANDARD ?= c++17
#CXXFLAGS_BOOST ?=
CXXFLAGS_SUNDIALS ?= -pipe $(CXXFLAGS_OPTIM_SUNDIALS) $(CPPFLAGS_FLTO_SUNDIALS)
#CXXFLAGS_GTEST
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/grad_reg_inc_gamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ namespace math {
(a-1)_k\right) \frac{1}{z^k} \end{array} \f]
*/
template <typename T1, typename T2>
return_type_t<T1, T2> grad_reg_inc_gamma(T1 a, T2 z, T1 g, T1 dig,
double precision = 1e-6,
int max_steps = 1e5) {
inline return_type_t<T1, T2> grad_reg_inc_gamma(T1 a, T2 z, T1 g, T1 dig,
double precision = 1e-6,
int max_steps = 1e5) {
using std::exp;
using std::fabs;
using std::log;
Expand Down
8 changes: 8 additions & 0 deletions stan/math/prim/meta/is_autodiff.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ struct is_autodiff
: bool_constant<math::disjunction<is_var<std::decay_t<T>>,
is_fvar<std::decay_t<T>>>::value> {};

template <typename... Types>
inline constexpr bool is_autodiff_v
= math::conjunction<is_autodiff<Types>...>::value;

template <typename... Types>
inline constexpr bool is_autodiffable_v
= math::conjunction<is_autodiff<scalar_type_t<Types>>...>::value;

Comment on lines +22 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two typedefs seem to test the same thing. Also, shouldn't this be is_autodiff_all_v?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These need better names. is_autodiff_v<> looks to see is the type is a fvar or var and fails otherwise. is_autodiffable_v looks into the scalar type of the type to see if it is autodiffable, so things like eigen matrices and vectors of var or fvar types would be true for is_autodiffable_v.

I left is_autodiff alone to not mess with the other functions that use it (mostly functions that use it in a requires). Maybe the current is_autodiff should be named is_autodiff_scalar?

/*! \ingroup require_stan_scalar_real */
/*! \defgroup autodiff_types autodiff */
/*! \addtogroup autodiff_types */
Expand Down
7 changes: 7 additions & 0 deletions stan/math/prim/meta/is_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,12 @@ template <typename T>
struct is_constant<T, require_eigen_t<T>>
: bool_constant<is_constant<typename std::decay_t<T>::Scalar>::value> {};

template <typename... Types>
inline constexpr bool is_constant_all_v = is_constant_all<Types...>::value;

template <typename... Types>
inline constexpr bool is_constant_v
= std::conjunction<is_constant<Types>...>::value;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
= std::conjunction<is_constant<Types>...>::value;
= math::conjunction<is_constant<Types>...>::value;

For consistency

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also same comment about redundancy - is there a situation where is_constant_all_v would behave differently to is_constant_v here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think we should start using the std version! We just used the stan math version because we didn't have c++17 available previously

Also same comment about redundancy - is there a situation where is_constant_all_v would behave differently to is_constant_v here?

Since I made is_constant_all_v to accept multiple types there should be no difference. I'm going to put up another PR where throughout the math library I change is_constant_all with is_constant_v. Which should clean up a lot.


} // namespace stan
#endif
4 changes: 4 additions & 0 deletions stan/math/prim/meta/is_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ template <typename T>
struct is_matrix
: bool_constant<math::disjunction<is_rev_matrix<T>, is_eigen<T>>::value> {};

template <typename... Types>
inline constexpr bool is_matrix_v
= stan::math::conjunction<is_matrix<Types>...>::value;

/*! \ingroup require_eigens_types */
/*! \defgroup matrix_types matrix */
/*! \addtogroup matrix_types */
Expand Down
4 changes: 4 additions & 0 deletions stan/math/prim/meta/is_stan_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ struct is_stan_scalar
is_fvar<std::decay_t<T>>, std::is_arithmetic<std::decay_t<T>>,
is_complex<std::decay_t<T>>>::value> {};

template <typename... Types>
inline constexpr bool is_stan_scalar_v
= std::conjunction<is_stan_scalar<Types>...>::value;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
= std::conjunction<is_stan_scalar<Types>...>::value;
= math::conjunction<is_stan_scalar<Types>...>::value;


/*! \ingroup require_stan_scalar_real */
/*! \defgroup stan_scalar_types stan_scalar */
/*! \addtogroup stan_scalar_types */
Expand Down
28 changes: 14 additions & 14 deletions stan/math/rev/fun/append_col.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,24 @@ template <typename T1, typename T2, require_any_var_matrix_t<T1, T2>* = nullptr>
inline auto append_col(const T1& A, const T2& B) {
check_size_match("append_col", "columns of A", A.rows(), "columns of B",
B.rows());
if (!is_constant<T1>::value && !is_constant<T2>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;
arena_t<promote_scalar_t<var, T2>> arena_B = B;
if constexpr (is_autodiffable_v<T1, T2>) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be !is_constant_all_v<T1, T2> for consistency with the original definition?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Applies to the other changes as well)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm confused on definitions. If something is not constant then shouldn't it be autodiffable (i.e. var and fvar?)

arena_t<T1> arena_A = A;
arena_t<T2> arena_B = B;
return make_callback_var(
append_col(value_of(arena_A), value_of(arena_B)),
[arena_A, arena_B](auto& vi) mutable {
arena_A.adj() += vi.adj().leftCols(arena_A.cols());
arena_B.adj() += vi.adj().rightCols(arena_B.cols());
});
} else if (!is_constant<T1>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;
} else if constexpr (is_autodiffable_v<T1>) {
arena_t<T1> arena_A = A;
return make_callback_var(append_col(value_of(arena_A), value_of(B)),
[arena_A](auto& vi) mutable {
arena_A.adj()
+= vi.adj().leftCols(arena_A.cols());
});
} else {
arena_t<promote_scalar_t<var, T2>> arena_B = B;
arena_t<T2> arena_B = B;
return make_callback_var(append_col(value_of(A), value_of(arena_B)),
[arena_B](auto& vi) mutable {
arena_B.adj()
Expand All @@ -79,21 +79,21 @@ template <typename Scal, typename RowVec,
require_stan_scalar_t<Scal>* = nullptr,
require_t<is_eigen_row_vector<RowVec>>* = nullptr>
inline auto append_col(const Scal& A, const var_value<RowVec>& B) {
if (!is_constant<Scal>::value && !is_constant<RowVec>::value) {
if constexpr (is_autodiffable_v<Scal, RowVec>) {
var arena_A = A;
arena_t<promote_scalar_t<var, RowVec>> arena_B = B;
arena_t<RowVec> arena_B = B;
return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)),
[arena_A, arena_B](auto& vi) mutable {
arena_A.adj() += vi.adj().coeff(0);
arena_B.adj() += vi.adj().tail(arena_B.size());
});
} else if (!is_constant<Scal>::value) {
} else if constexpr (is_autodiffable_v<Scal>) {
var arena_A = A;
return make_callback_var(
append_col(value_of(arena_A), value_of(B)),
[arena_A](auto& vi) mutable { arena_A.adj() += vi.adj().coeff(0); });
} else {
arena_t<promote_scalar_t<var, RowVec>> arena_B = B;
arena_t<RowVec> arena_B = B;
return make_callback_var(append_col(value_of(A), value_of(arena_B)),
[arena_B](auto& vi) mutable {
arena_B.adj() += vi.adj().tail(arena_B.size());
Expand All @@ -119,17 +119,17 @@ template <typename RowVec, typename Scal,
require_t<is_eigen_row_vector<RowVec>>* = nullptr,
require_stan_scalar_t<Scal>* = nullptr>
inline auto append_col(const var_value<RowVec>& A, const Scal& B) {
if (!is_constant<RowVec>::value && !is_constant<Scal>::value) {
arena_t<promote_scalar_t<var, RowVec>> arena_A = A;
if constexpr (is_autodiffable_v<RowVec, Scal>) {
arena_t<RowVec> arena_A = A;
var arena_B = B;
return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)),
[arena_A, arena_B](auto& vi) mutable {
arena_A.adj() += vi.adj().head(arena_A.size());
arena_B.adj()
+= vi.adj().coeff(vi.adj().size() - 1);
});
} else if (!is_constant<RowVec>::value) {
arena_t<promote_scalar_t<var, RowVec>> arena_A = A;
} else if constexpr (is_autodiffable_v<RowVec>) {
arena_t<RowVec> arena_A = A;
return make_callback_var(append_col(value_of(arena_A), value_of(B)),
[arena_A](auto& vi) mutable {
arena_A.adj() += vi.adj().head(arena_A.size());
Expand Down
30 changes: 15 additions & 15 deletions stan/math/rev/fun/append_row.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,24 @@ template <typename T1, typename T2, require_any_var_matrix_t<T1, T2>* = nullptr>
inline auto append_row(const T1& A, const T2& B) {
check_size_match("append_row", "columns of A", A.cols(), "columns of B",
B.cols());
if (!is_constant<T1>::value && !is_constant<T2>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;
arena_t<promote_scalar_t<var, T2>> arena_B = B;
if constexpr (is_autodiffable_v<T1, T2>) {
arena_t<T1> arena_A = A;
arena_t<T2> arena_B = B;
return make_callback_var(
append_row(value_of(arena_A), value_of(arena_B)),
[arena_A, arena_B](auto& vi) mutable {
arena_A.adj() += vi.adj().topRows(arena_A.rows());
arena_B.adj() += vi.adj().bottomRows(arena_B.rows());
});
} else if (!is_constant<T1>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;
} else if constexpr (is_autodiffable_v<T1>) {
arena_t<T1> arena_A = A;
return make_callback_var(append_row(value_of(arena_A), value_of(B)),
[arena_A](auto& vi) mutable {
arena_A.adj()
+= vi.adj().topRows(arena_A.rows());
});
} else {
arena_t<promote_scalar_t<var, T2>> arena_B = B;
arena_t<T2> arena_B = B;
return make_callback_var(append_row(value_of(A), value_of(arena_B)),
[arena_B](auto& vi) mutable {
arena_B.adj()
Expand All @@ -76,21 +76,21 @@ template <typename Scal, typename ColVec,
require_stan_scalar_t<Scal>* = nullptr,
require_t<is_eigen_col_vector<ColVec>>* = nullptr>
inline auto append_row(const Scal& A, const var_value<ColVec>& B) {
if (!is_constant<Scal>::value && !is_constant<ColVec>::value) {
if constexpr (is_autodiffable_v<Scal, ColVec>) {
var arena_A = A;
arena_t<promote_scalar_t<var, ColVec>> arena_B = B;
arena_t<ColVec> arena_B = B;
return make_callback_var(append_row(value_of(arena_A), value_of(arena_B)),
[arena_A, arena_B](auto& vi) mutable {
arena_A.adj() += vi.adj().coeff(0);
arena_B.adj() += vi.adj().tail(arena_B.size());
});
} else if (!is_constant<Scal>::value) {
} else if constexpr (is_autodiffable_v<Scal>) {
var arena_A = A;
return make_callback_var(
append_row(value_of(arena_A), value_of(B)),
[arena_A](auto& vi) mutable { arena_A.adj() += vi.adj().coeff(0); });
} else {
arena_t<promote_scalar_t<var, ColVec>> arena_B = B;
arena_t<ColVec> arena_B = B;
return make_callback_var(append_row(value_of(A), value_of(arena_B)),
[arena_B](auto& vi) mutable {
arena_B.adj() += vi.adj().tail(arena_B.size());
Expand All @@ -115,23 +115,23 @@ template <typename ColVec, typename Scal,
require_t<is_eigen_col_vector<ColVec>>* = nullptr,
require_stan_scalar_t<Scal>* = nullptr>
inline auto append_row(const var_value<ColVec>& A, const Scal& B) {
if (!is_constant<ColVec>::value && !is_constant<Scal>::value) {
arena_t<promote_scalar_t<var, ColVec>> arena_A = A;
if constexpr (is_autodiffable_v<ColVec, Scal>) {
arena_t<ColVec> arena_A = A;
var arena_B = B;
return make_callback_var(append_row(value_of(arena_A), value_of(arena_B)),
[arena_A, arena_B](auto& vi) mutable {
arena_A.adj() += vi.adj().head(arena_A.size());
arena_B.adj()
+= vi.adj().coeff(vi.adj().size() - 1);
});
} else if (!is_constant<ColVec>::value) {
arena_t<promote_scalar_t<var, ColVec>> arena_A = A;
} else if constexpr (is_autodiffable_v<ColVec>) {
arena_t<ColVec> arena_A = A;
return make_callback_var(append_row(value_of(arena_A), value_of(B)),
[arena_A](auto& vi) mutable {
arena_A.adj() += vi.adj().head(arena_A.size());
});
} else {
arena_t<promote_scalar_t<var, Scal>> arena_B = B;
arena_t<Scal> arena_B = B;
return make_callback_var(append_row(value_of(A), value_of(arena_B)),
[arena_B](auto& vi) mutable {
arena_B.adj()
Expand Down
Loading