aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-01-16 10:22:07 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2021-01-17 08:50:57 -0800
commitbde6741641b7c677d901cd48db844fcea1fd32fe (patch)
treef25e6540b247ed01df1bdadc41502fbc332321d3 /Eigen
parent21a8a2487c824e5ae05566f4fcc49540053b2702 (diff)
Improved std::complex sqrt and rsqrt.
Replaces `std::sqrt` with `complex_sqrt` for all platforms (previously `complex_sqrt` was only used for CUDA and MSVC), and implements custom `complex_rsqrt`. Also introduces `numext::rsqrt` to simplify implementation, and modified `numext::hypot` to adhere to IEEE IEC 6059 for special cases. The `complex_sqrt` and `complex_rsqrt` implementations were found to be significantly faster than `std::sqrt<std::complex<T>>` and `1/numext::sqrt<std::complex<T>>`. Benchmark file attached. ``` GCC 10, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt<std::complex<float>> 9.21 ns 9.21 ns 73225448 BM_StdSqrt<std::complex<float>> 17.1 ns 17.1 ns 40966545 BM_Sqrt<std::complex<double>> 8.53 ns 8.53 ns 81111062 BM_StdSqrt<std::complex<double>> 21.5 ns 21.5 ns 32757248 BM_Rsqrt<std::complex<float>> 10.3 ns 10.3 ns 68047474 BM_DivSqrt<std::complex<float>> 16.3 ns 16.3 ns 42770127 BM_Rsqrt<std::complex<double>> 11.3 ns 11.3 ns 61322028 BM_DivSqrt<std::complex<double>> 16.5 ns 16.5 ns 42200711 Clang 11, Intel Xeon, x86_64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt<std::complex<float>> 7.46 ns 7.45 ns 90742042 BM_StdSqrt<std::complex<float>> 16.6 ns 16.6 ns 42369878 BM_Sqrt<std::complex<double>> 8.49 ns 8.49 ns 81629030 BM_StdSqrt<std::complex<double>> 21.8 ns 21.7 ns 31809588 BM_Rsqrt<std::complex<float>> 8.39 ns 8.39 ns 82933666 BM_DivSqrt<std::complex<float>> 14.4 ns 14.4 ns 48638676 BM_Rsqrt<std::complex<double>> 9.83 ns 9.82 ns 70068956 BM_DivSqrt<std::complex<double>> 15.7 ns 15.7 ns 44487798 Clang 9, Pixel 2, aarch64: --------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_Sqrt<std::complex<float>> 24.2 ns 24.1 ns 28616031 BM_StdSqrt<std::complex<float>> 104 ns 103 ns 6826926 BM_Sqrt<std::complex<double>> 31.8 ns 31.8 ns 22157591 BM_StdSqrt<std::complex<double>> 128 ns 128 ns 5437375 BM_Rsqrt<std::complex<float>> 31.9 ns 31.8 ns 22384383 BM_DivSqrt<std::complex<float>> 99.2 ns 98.9 ns 7250438 BM_Rsqrt<std::complex<double>> 46.0 ns 45.8 ns 15338689 BM_DivSqrt<std::complex<double>> 119 ns 119 ns 5898944 ```
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/Core/GenericPacketMath.h3
-rw-r--r--Eigen/src/Core/MathFunctions.h107
-rw-r--r--Eigen/src/Core/MathFunctionsImpl.h60
-rw-r--r--Eigen/src/Core/arch/CUDA/Complex.h13
-rw-r--r--Eigen/src/Core/arch/SSE/MathFunctions.h3
-rw-r--r--Eigen/src/Core/functors/UnaryFunctors.h2
6 files changed, 127 insertions, 61 deletions
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index ec7d20e73..16119c1d8 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -250,8 +250,7 @@ template<> EIGEN_DEVICE_FUNC inline double pzero<double>(const double& a) {
template <typename RealScalar>
EIGEN_DEVICE_FUNC inline std::complex<RealScalar> ptrue(const std::complex<RealScalar>& /*a*/) {
- RealScalar b;
- b = ptrue(b);
+ RealScalar b = ptrue(RealScalar(0));
return std::complex<RealScalar>(b, b);
}
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index f64116a41..511a4276f 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -324,7 +324,7 @@ struct abs2_retval
};
/****************************************************************************
-* Implementation of sqrt *
+* Implementation of sqrt/rsqrt *
****************************************************************************/
template<typename Scalar>
@@ -341,8 +341,8 @@ struct sqrt_impl
// Complex sqrt defined in MathFunctionsImpl.h.
template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& a_x);
-// MSVC incorrectly handles inf cases.
-#if EIGEN_COMP_MSVC > 0
+// Custom implementation is faster than `std::sqrt`, works on
+// GPU, and correctly handles special cases (unlike MSVC).
template<typename T>
struct sqrt_impl<std::complex<T> >
{
@@ -352,7 +352,6 @@ struct sqrt_impl<std::complex<T> >
return complex_sqrt<T>(x);
}
};
-#endif
template<typename Scalar>
struct sqrt_retval
@@ -360,6 +359,29 @@ struct sqrt_retval
typedef Scalar type;
};
+// Default implementation relies on numext::sqrt, at bottom of file.
+template<typename T>
+struct rsqrt_impl;
+
+// Complex rsqrt defined in MathFunctionsImpl.h.
+template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& a_x);
+
+template<typename T>
+struct rsqrt_impl<std::complex<T> >
+{
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
+ {
+ return complex_rsqrt<T>(x);
+ }
+};
+
+template<typename Scalar>
+struct rsqrt_retval
+{
+ typedef Scalar type;
+};
+
/****************************************************************************
* Implementation of norm1 *
****************************************************************************/
@@ -623,36 +645,6 @@ struct expm1_impl {
}
};
-// Specialization for complex types that are not supported by std::expm1.
-template <typename RealScalar>
-struct expm1_impl<std::complex<RealScalar> > {
- EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(
- const std::complex<RealScalar>& x) {
- EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar)
- RealScalar xr = x.real();
- RealScalar xi = x.imag();
- // expm1(z) = exp(z) - 1
- // = exp(x + i * y) - 1
- // = exp(x) * (cos(y) + i * sin(y)) - 1
- // = exp(x) * cos(y) - 1 + i * exp(x) * sin(y)
- // Imag(expm1(z)) = exp(x) * sin(y)
- // Real(expm1(z)) = exp(x) * cos(y) - 1
- // = exp(x) * cos(y) - 1.
- // = expm1(x) + exp(x) * (cos(y) - 1)
- // = expm1(x) + exp(x) * (2 * sin(y / 2) ** 2)
-
- // TODO better use numext::expm1 and numext::sin (but that would require forward declarations or moving this specialization down).
- RealScalar erm1 = expm1_impl<RealScalar>::run(xr);
- RealScalar er = erm1 + RealScalar(1.);
- EIGEN_USING_STD(sin);
- RealScalar sin2 = sin(xi / RealScalar(2.));
- sin2 = sin2 * sin2;
- RealScalar s = sin(xi);
- RealScalar real_part = erm1 - RealScalar(2.) * er * sin2;
- return std::complex<RealScalar>(real_part, er * s);
- }
-};
-
template<typename Scalar>
struct expm1_retval
{
@@ -1421,6 +1413,14 @@ bool sqrt<bool>(const bool &x) { return x; }
SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt)
#endif
+/** \returns the reciprocal square root of \a x. **/
+template<typename T>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+T rsqrt(const T& x)
+{
+ return internal::rsqrt_impl<T>::run(x);
+}
+
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T log(const T &x) {
@@ -1936,6 +1936,45 @@ template<> struct scalar_fuzzy_impl<bool>
};
+} // end namespace internal
+
+// Default implementations that rely on other numext implementations
+namespace internal {
+
+// Specialization for complex types that are not supported by std::expm1.
+template <typename RealScalar>
+struct expm1_impl<std::complex<RealScalar> > {
+ EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(
+ const std::complex<RealScalar>& x) {
+ EIGEN_STATIC_ASSERT_NON_INTEGER(RealScalar)
+ RealScalar xr = x.real();
+ RealScalar xi = x.imag();
+ // expm1(z) = exp(z) - 1
+ // = exp(x + i * y) - 1
+ // = exp(x) * (cos(y) + i * sin(y)) - 1
+ // = exp(x) * cos(y) - 1 + i * exp(x) * sin(y)
+ // Imag(expm1(z)) = exp(x) * sin(y)
+ // Real(expm1(z)) = exp(x) * cos(y) - 1
+ // = exp(x) * cos(y) - 1.
+ // = expm1(x) + exp(x) * (cos(y) - 1)
+ // = expm1(x) + exp(x) * (2 * sin(y / 2) ** 2)
+ RealScalar erm1 = numext::expm1<RealScalar>(xr);
+ RealScalar er = erm1 + RealScalar(1.);
+ RealScalar sin2 = numext::sin(xi / RealScalar(2.));
+ sin2 = sin2 * sin2;
+ RealScalar s = numext::sin(xi);
+ RealScalar real_part = erm1 - RealScalar(2.) * er * sin2;
+ return std::complex<RealScalar>(real_part, er * s);
+ }
+};
+
+template<typename T>
+struct rsqrt_impl {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_ALWAYS_INLINE T run(const T& x) {
+ return T(1)/numext::sqrt(x);
+ }
+};
} // end namespace internal
diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h
index 9222285b4..0d3f317bb 100644
--- a/Eigen/src/Core/MathFunctionsImpl.h
+++ b/Eigen/src/Core/MathFunctionsImpl.h
@@ -79,6 +79,12 @@ template<typename RealScalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
RealScalar positive_real_hypot(const RealScalar& x, const RealScalar& y)
{
+ // IEEE IEC 6059 special cases.
+ if ((numext::isinf)(x) || (numext::isinf)(y))
+ return NumTraits<RealScalar>::infinity();
+ if ((numext::isnan)(x) || (numext::isnan)(y))
+ return NumTraits<RealScalar>::quiet_NaN();
+
EIGEN_USING_STD(sqrt);
RealScalar p, qp;
p = numext::maxi(x,y);
@@ -128,20 +134,56 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
const T x = numext::real(z);
const T y = numext::imag(z);
const T zero = T(0);
- const T cst_half = T(0.5);
+ const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
- // Special case of isinf(y)
- if ((numext::isinf)(y)) {
- return std::complex<T>(std::numeric_limits<T>::infinity(), y);
- }
-
- T w = numext::sqrt(cst_half * (numext::abs(x) + numext::abs(z)));
return
- x == zero ? std::complex<T>(w, y < zero ? -w : w)
- : x > zero ? std::complex<T>(w, y / (2 * w))
+ (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y)
+ : x == zero ? std::complex<T>(w, y < zero ? -w : w)
+ : x > zero ? std::complex<T>(w, y / (2 * w))
: std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w );
}
+// Generic complex rsqrt implementation.
+template<typename T>
+EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
+ // Computes the principal reciprocal sqrt of the input.
+ //
+ // For a complex reciprocal square root of the number z = x + i*y. We want to
+ // find real numbers u and v such that
+ // (u + i*v)^2 = 1 / (x + i*y) <=>
+ // u^2 - v^2 + i*2*u*v = x/|z|^2 - i*v/|z|^2.
+ // By equating the real and imaginary parts we get:
+ // u^2 - v^2 = x/|z|^2
+ // 2*u*v = y/|z|^2.
+ //
+ // For x >= 0, this has the numerically stable solution
+ // u = sqrt(0.5 * (x + |z|)) / |z|
+ // v = -y / (2 * u * |z|)
+ // and for x < 0,
+ // v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z|
+ // u = -y / (2 * v * |z|)
+ //
+ // Letting w = sqrt(0.5 * (|x| + |z|)),
+ // if x == 0: u = w / |z|, v = -sign(y) * w / |z|
+ // if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
+ // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
+
+ const T x = numext::real(z);
+ const T y = numext::imag(z);
+ const T zero = T(0);
+
+ const T abs_z = numext::hypot(x, y);
+ const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
+ const T woz = w / abs_z;
+ // Corner cases consistent with 1/sqrt(z) on gcc/clang.
+ return
+ abs_z == zero ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
+ : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero)
+ : x == zero ? std::complex<T>(woz, y < zero ? woz : -woz)
+ : x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z))
+ : std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz );
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h
index df5a3c2a4..6e77372b0 100644
--- a/Eigen/src/Core/arch/CUDA/Complex.h
+++ b/Eigen/src/Core/arch/CUDA/Complex.h
@@ -94,19 +94,6 @@ template<typename T> struct scalar_quotient_op<const std::complex<T>, const std:
template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
-// Complex sqrt is already specialized on Windows.
-#if EIGEN_COMP_MSVC == 0
-template<typename T>
-struct sqrt_impl<std::complex<T> >
-{
- EIGEN_DEVICE_FUNC
- static EIGEN_ALWAYS_INLINE std::complex<T> run(const std::complex<T>& x)
- {
- return complex_sqrt<T>(x);
- }
-};
-#endif
-
} // namespace internal
} // namespace Eigen
diff --git a/Eigen/src/Core/arch/SSE/MathFunctions.h b/Eigen/src/Core/arch/SSE/MathFunctions.h
index 9f66d8ab3..8736d0d6b 100644
--- a/Eigen/src/Core/arch/SSE/MathFunctions.h
+++ b/Eigen/src/Core/arch/SSE/MathFunctions.h
@@ -150,7 +150,7 @@ Packet4f prsqrt<Packet4f>(const Packet4f& _x) {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
- // Unfortunately we can't use the much faster mm_rqsrt_ps since it only provides an approximation.
+ // Unfortunately we can't use the much faster mm_rsqrt_ps since it only provides an approximation.
return _mm_div_ps(pset1<Packet4f>(1.0f), _mm_sqrt_ps(x));
}
@@ -158,7 +158,6 @@ Packet4f prsqrt<Packet4f>(const Packet4f& x) {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d prsqrt<Packet2d>(const Packet2d& x) {
- // Unfortunately we can't use the much faster mm_rqsrt_pd since it only provides an approximation.
return _mm_div_pd(pset1<Packet2d>(1.0), _mm_sqrt_pd(x));
}
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index eee6ae194..976ecba59 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -456,7 +456,7 @@ struct functor_traits<scalar_sqrt_op<bool> > {
*/
template<typename Scalar> struct scalar_rsqrt_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op)
- EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(1)/numext::sqrt(a); }
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::rsqrt(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::prsqrt(a); }
};