From 7e6a1c129c201db4eff46f4dd68acdc7e935eaf2 Mon Sep 17 00:00:00 2001 From: Nathan Luehr Date: Mon, 19 Apr 2021 18:05:27 -0500 Subject: Device implementation of log for std::complex types. --- Eigen/src/Core/MathFunctions.h | 30 +++++++++++++++++++++++++++--- Eigen/src/Core/MathFunctionsImpl.h | 9 +++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 7f82090a9..d7ac4d64d 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -2,6 +2,7 @@ // for linear algebra. // // Copyright (C) 2006-2010 Benoit Jacob +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -687,6 +688,30 @@ struct expm1_retval typedef Scalar type; }; +/**************************************************************************** +* Implementation of log * +****************************************************************************/ + +// Complex log defined in MathFunctionsImpl.h. +template EIGEN_DEVICE_FUNC std::complex complex_log(const std::complex& z); + +template +struct log_impl { + EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x) + { + EIGEN_USING_STD(log); + return static_cast(log(x)); + } +}; + +template +struct log_impl > { + EIGEN_DEVICE_FUNC static inline std::complex run(const std::complex& z) + { + return complex_log(z); + } +}; + /**************************************************************************** * Implementation of log1p * ****************************************************************************/ @@ -700,7 +725,7 @@ namespace std_fallback { typedef typename NumTraits::Real RealScalar; EIGEN_USING_STD(log); Scalar x1p = RealScalar(1) + x; - Scalar log_1p = log(x1p); + Scalar log_1p = log_impl::run(x1p); const bool is_small = numext::equal_strict(x1p, Scalar(1)); const bool is_inf = numext::equal_strict(x1p, log_1p); return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1))); @@ -1460,8 +1485,7 @@ T rsqrt(const T& x) template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log(const T &x) { - EIGEN_USING_STD(log); - return static_cast(log(x)); + return internal::log_impl::run(x); } #if defined(SYCL_DEVICE_ONLY) diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 0d3f317bb..4eaaaa784 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -184,6 +184,15 @@ EIGEN_DEVICE_FUNC std::complex complex_rsqrt(const std::complex& z) { : std::complex(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz ); } +template +EIGEN_DEVICE_FUNC std::complex complex_log(const std::complex& z) { + // Computes complex log. + T a = numext::abs(z); + EIGEN_USING_STD(atan2); + T b = atan2(z.imag(), z.real()); + return std::complex(numext::log(a), b); +} + } // end namespace internal } // end namespace Eigen -- cgit v1.2.3