From b3c2eeaec8b67dfbb05840b53cccfcb562e38313 Mon Sep 17 00:00:00 2001 From: "A. Jiang" Date: Fri, 27 Sep 2024 18:04:36 +0800 Subject: [PATCH] [libc++] Constrain additional overloads of `pow` for `complex` harder --- libcxx/include/complex | 6 +- .../complex.number/cmplx.over.pow.pass.cpp | 106 ++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp diff --git a/libcxx/include/complex b/libcxx/include/complex index 4030d96b003d56..15e42800fbfa0a 100644 --- a/libcxx/include/complex +++ b/libcxx/include/complex @@ -1097,20 +1097,20 @@ inline _LIBCPP_HIDE_FROM_ABI complex<_Tp> pow(const complex<_Tp>& __x, const com return std::exp(__y * std::log(__x)); } -template +template ::value && is_floating_point<_Up>::value, int> = 0> inline _LIBCPP_HIDE_FROM_ABI complex::type> pow(const complex<_Tp>& __x, const complex<_Up>& __y) { typedef complex::type> result_type; return std::pow(result_type(__x), result_type(__y)); } -template ::value, int> = 0> +template ::value && is_arithmetic<_Up>::value, int> = 0> inline _LIBCPP_HIDE_FROM_ABI complex::type> pow(const complex<_Tp>& __x, const _Up& __y) { typedef complex::type> result_type; return std::pow(result_type(__x), result_type(__y)); } -template ::value, int> = 0> +template ::value && is_floating_point<_Up>::value, int> = 0> inline _LIBCPP_HIDE_FROM_ABI complex::type> pow(const _Tp& __x, const complex<_Up>& __y) { typedef complex::type> result_type; return std::pow(result_type(__x), result_type(__y)); diff --git a/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp new file mode 100644 index 00000000000000..64e679fed7435c --- /dev/null +++ b/libcxx/test/libcxx/numerics/complex.number/cmplx.over.pow.pass.cpp @@ -0,0 +1,106 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// + +// template complex<__promote::type> pow(const complex&, const U&); +// template complex<__promote::type> pow(const complex&, const complex&); +// template complex<__promote::type> pow(const T&, const complex&); + +// Test that these additional overloads are free from catching std::complex, +// which is expected by several 3rd party libraries, see https://github.com/llvm/llvm-project/issues/109858. + +#include +#include +#include +#include + +#include "test_macros.h" + +namespace usr { +struct usr_tag {}; + +template +TEST_CONSTEXPR + typename std::enable_if<(std::is_same::value && std::is_floating_point::value) || + (std::is_floating_point::value && std::is_same::value), + int>::type + pow(const T&, const std::complex&) { + return std::is_same::value ? 0 : 1; +} + +template +TEST_CONSTEXPR + typename std::enable_if<(std::is_same::value && std::is_floating_point::value) || + (std::is_floating_point::value && std::is_same::value), + int>::type + pow(const std::complex&, const U&) { + return std::is_same::value ? 2 : 3; +} + +template +TEST_CONSTEXPR + typename std::enable_if<(std::is_same::value && std::is_floating_point::value) || + (std::is_floating_point::value && std::is_same::value), + int>::type + pow(const std::complex&, const std::complex&) { + return std::is_same::value ? 4 : 5; +} +} // namespace usr + +int main(int, char**) { + using std::pow; + using usr::pow; + + TEST_CONSTEXPR usr::usr_tag tag; + TEST_CONSTEXPR_CXX14 const std::complex ctag; + + assert(pow(tag, std::complex(1.0f)) == 0); + assert(pow(std::complex(1.0f), tag) == 2); + assert(pow(tag, std::complex(1.0)) == 0); + assert(pow(std::complex(1.0), tag) == 2); + assert(pow(tag, std::complex(1.0l)) == 0); + assert(pow(std::complex(1.0l), tag) == 2); + + assert(pow(1.0f, ctag) == 1); + assert(pow(ctag, 1.0f) == 3); + assert(pow(1.0, ctag) == 1); + assert(pow(ctag, 1.0) == 3); + assert(pow(1.0l, ctag) == 1); + assert(pow(ctag, 1.0l) == 3); + + assert(pow(ctag, std::complex(1.0f)) == 4); + assert(pow(std::complex(1.0f), ctag) == 5); + assert(pow(ctag, std::complex(1.0)) == 4); + assert(pow(std::complex(1.0), ctag) == 5); + assert(pow(ctag, std::complex(1.0l)) == 4); + assert(pow(std::complex(1.0l), ctag) == 5); + +#if TEST_STD_VER >= 11 + static_assert(pow(tag, std::complex(1.0f)) == 0, ""); + static_assert(pow(std::complex(1.0f), tag) == 2, ""); + static_assert(pow(tag, std::complex(1.0)) == 0, ""); + static_assert(pow(std::complex(1.0), tag) == 2, ""); + static_assert(pow(tag, std::complex(1.0l)) == 0, ""); + static_assert(pow(std::complex(1.0l), tag) == 2, ""); + + static_assert(pow(1.0f, ctag) == 1, ""); + static_assert(pow(ctag, 1.0f) == 3, ""); + static_assert(pow(1.0, ctag) == 1, ""); + static_assert(pow(ctag, 1.0) == 3, ""); + static_assert(pow(1.0l, ctag) == 1, ""); + static_assert(pow(ctag, 1.0l) == 3, ""); + + static_assert(pow(ctag, std::complex(1.0f)) == 4, ""); + static_assert(pow(std::complex(1.0f), ctag) == 5, ""); + static_assert(pow(ctag, std::complex(1.0)) == 4, ""); + static_assert(pow(std::complex(1.0), ctag) == 5, ""); + static_assert(pow(ctag, std::complex(1.0l)) == 4, ""); + static_assert(pow(std::complex(1.0l), ctag) == 5, ""); +#endif +}