From 7e6a1c129c201db4eff46f4dd68acdc7e935eaf2 Mon Sep 17 00:00:00 2001 From: Nathan Luehr Date: Mon, 19 Apr 2021 18:05:27 -0500 Subject: Device implementation of log for std::complex types. --- Eigen/src/Core/MathFunctions.h | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) (limited to 'Eigen/src/Core/MathFunctions.h') diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 7f82090a9..d7ac4d64d 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -2,6 +2,7 @@ // for linear algebra. // // Copyright (C) 2006-2010 Benoit Jacob +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -687,6 +688,30 @@ struct expm1_retval typedef Scalar type; }; +/**************************************************************************** +* Implementation of log * +****************************************************************************/ + +// Complex log defined in MathFunctionsImpl.h. +template EIGEN_DEVICE_FUNC std::complex complex_log(const std::complex& z); + +template +struct log_impl { + EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x) + { + EIGEN_USING_STD(log); + return static_cast(log(x)); + } +}; + +template +struct log_impl > { + EIGEN_DEVICE_FUNC static inline std::complex run(const std::complex& z) + { + return complex_log(z); + } +}; + /**************************************************************************** * Implementation of log1p * ****************************************************************************/ @@ -700,7 +725,7 @@ namespace std_fallback { typedef typename NumTraits::Real RealScalar; EIGEN_USING_STD(log); Scalar x1p = RealScalar(1) + x; - Scalar log_1p = log(x1p); + Scalar log_1p = log_impl::run(x1p); const bool is_small = numext::equal_strict(x1p, Scalar(1)); const bool is_inf = numext::equal_strict(x1p, log_1p); return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1))); @@ -1460,8 +1485,7 @@ T rsqrt(const T& x) template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log(const T &x) { - EIGEN_USING_STD(log); - return static_cast(log(x)); + return internal::log_impl::run(x); } #if defined(SYCL_DEVICE_ONLY) -- cgit v1.2.3