From 39baff850c2f4fe1fee3b7a3918ba62a526e4f08 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 2 Jun 2016 17:04:19 -0700 Subject: Add TernaryFunctors and the betainc SpecialFunction. TernaryFunctors and their executors allow operations on 3-tuples of inputs. API fully implemented for Arrays and Tensors based on binary functors. Ported the cephes betainc function (regularized incomplete beta integral) to Eigen, with support for CPU and GPU, floats, doubles, and half types. Added unit tests in array.cpp and cxx11_tensor_cuda.cu Collapsed revision * Merged helper methods for betainc across floats and doubles. * Added TensorGlobalFunctions with betainc(). Removed betainc() from TensorBase. * Clean up CwiseTernaryOp checks, change igamma_helper to cephes_helper. * betainc: merge incbcf and incbd into incbeta_cfe. and more cleanup. * Update TernaryOp and SpecialFunctions (betainc) based on review comments. --- unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 96 +++++++++++++++++++++++++ 1 file changed, 96 insertions(+) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index ea250d8bc..9509f8002 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -218,6 +218,102 @@ class TensorCwiseBinaryOp : public TensorBase +struct traits > +{ + // Type promotion to handle the case where the types of the args are different. + typedef typename result_of< + TernaryOp(typename Arg1XprType::Scalar, + typename Arg2XprType::Scalar, + typename Arg3XprType::Scalar)>::type Scalar; + EIGEN_STATIC_ASSERT( + (internal::is_same::StorageKind, + typename traits::StorageKind>::value), + STORAGE_KIND_MUST_MATCH) + EIGEN_STATIC_ASSERT( + (internal::is_same::StorageKind, + typename traits::StorageKind>::value), + STORAGE_KIND_MUST_MATCH) + EIGEN_STATIC_ASSERT( + (internal::is_same::Index, + typename traits::Index>::value), + STORAGE_INDEX_MUST_MATCH) + EIGEN_STATIC_ASSERT( + (internal::is_same::Index, + typename traits::Index>::value), + STORAGE_INDEX_MUST_MATCH) + typedef traits XprTraits; + typedef typename traits::StorageKind StorageKind; + typedef typename traits::Index Index; + typedef typename Arg1XprType::Nested Arg1Nested; + typedef typename Arg2XprType::Nested Arg2Nested; + typedef typename Arg3XprType::Nested Arg3Nested; + typedef typename remove_reference::type _Arg1Nested; + typedef typename remove_reference::type _Arg2Nested; + typedef typename remove_reference::type _Arg3Nested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; + + enum { + Flags = 0 + }; +}; + +template +struct eval, Eigen::Dense> +{ + typedef const TensorCwiseTernaryOp& type; +}; + +template +struct nested, 1, typename eval >::type> +{ + typedef TensorCwiseTernaryOp type; +}; + +} // end namespace internal + + + +template +class TensorCwiseTernaryOp : public TensorBase, ReadOnlyAccessors> +{ + public: + typedef typename Eigen::internal::traits::Scalar Scalar; + typedef typename Eigen::NumTraits::Real RealScalar; + typedef Scalar CoeffReturnType; + typedef typename Eigen::internal::nested::type Nested; + typedef typename Eigen::internal::traits::StorageKind StorageKind; + typedef typename Eigen::internal::traits::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp()) + : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {} + + EIGEN_DEVICE_FUNC + const TernaryOp& functor() const { return m_functor; } + + /** \returns the nested expressions */ + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + arg1Expression() const { return m_arg1_xpr; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + arg2Expression() const { return m_arg2_xpr; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + arg3Expression() const { return m_arg3_xpr; } + + protected: + typename Arg1XprType::Nested m_arg1_xpr; + typename Arg1XprType::Nested m_arg2_xpr; + typename Arg3XprType::Nested m_arg3_xpr; + const TernaryOp m_functor; +}; + + namespace internal { template struct traits > -- cgit v1.2.3