diff options
author | Nathan Luehr <nluehr@nvidia.com> | 2021-04-19 18:05:27 -0500 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-05-11 22:02:21 +0000 |
commit | 7e6a1c129c201db4eff46f4dd68acdc7e935eaf2 (patch) | |
tree | ce3662dea2b4b5329e89f7dc75f7bb1918bbff9b /Eigen/src/Core/MathFunctions.h | |
parent | 6753f0f197e7b8a8019e82e7b144ac0281d6a7f1 (diff) |
Device implementation of log for std::complex types.
Diffstat (limited to 'Eigen/src/Core/MathFunctions.h')
-rw-r--r-- | Eigen/src/Core/MathFunctions.h | 30 |
1 files changed, 27 insertions, 3 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 7f82090a9..d7ac4d64d 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -2,6 +2,7 @@ // for linear algebra. // // Copyright (C) 2006-2010 Benoit Jacob <jacob.benoit.1@gmail.com> +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -688,6 +689,30 @@ struct expm1_retval }; /**************************************************************************** +* Implementation of log * +****************************************************************************/ + +// Complex log defined in MathFunctionsImpl.h. +template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z); + +template<typename Scalar> +struct log_impl { + EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x) + { + EIGEN_USING_STD(log); + return static_cast<Scalar>(log(x)); + } +}; + +template<typename Scalar> +struct log_impl<std::complex<Scalar> > { + EIGEN_DEVICE_FUNC static inline std::complex<Scalar> run(const std::complex<Scalar>& z) + { + return complex_log(z); + } +}; + +/**************************************************************************** * Implementation of log1p * ****************************************************************************/ @@ -700,7 +725,7 @@ namespace std_fallback { typedef typename NumTraits<Scalar>::Real RealScalar; EIGEN_USING_STD(log); Scalar x1p = RealScalar(1) + x; - Scalar log_1p = log(x1p); + Scalar log_1p = log_impl<Scalar>::run(x1p); const bool is_small = numext::equal_strict(x1p, Scalar(1)); const bool is_inf = numext::equal_strict(x1p, log_1p); return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1))); @@ -1460,8 +1485,7 @@ T rsqrt(const T& x) template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log(const T &x) { - EIGEN_USING_STD(log); - return static_cast<T>(log(x)); + return internal::log_impl<T>::run(x); } #if defined(SYCL_DEVICE_ONLY) |