aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-06-02 17:04:19 -0700
committerGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-06-02 17:04:19 -0700
commit39baff850c2f4fe1fee3b7a3918ba62a526e4f08 (patch)
tree841ea12578450cfc0ab3e96a68d4e433a985a01d
parent02db4e1a82e6059cc217d6aa57bcc5ac6342eb37 (diff)
Add TernaryFunctors and the betainc SpecialFunction.
TernaryFunctors and their executors allow operations on 3-tuples of inputs. API fully implemented for Arrays and Tensors based on binary functors. Ported the cephes betainc function (regularized incomplete beta integral) to Eigen, with support for CPU and GPU, floats, doubles, and half types. Added unit tests in array.cpp and cxx11_tensor_cuda.cu Collapsed revision * Merged helper methods for betainc across floats and doubles. * Added TensorGlobalFunctions with betainc(). Removed betainc() from TensorBase. * Clean up CwiseTernaryOp checks, change igamma_helper to cephes_helper. * betainc: merge incbcf and incbd into incbeta_cfe. and more cleanup. * Update TernaryOp and SpecialFunctions (betainc) based on review comments.
-rw-r--r--Eigen/Core2
-rw-r--r--Eigen/src/Core/CoreEvaluators.h99
-rw-r--r--Eigen/src/Core/CwiseTernaryOp.h212
-rw-r--r--Eigen/src/Core/GenericPacketMath.h5
-rw-r--r--Eigen/src/Core/GlobalFunctions.h22
-rw-r--r--Eigen/src/Core/SpecialFunctions.h512
-rw-r--r--Eigen/src/Core/arch/CUDA/Half.h3
-rw-r--r--Eigen/src/Core/arch/CUDA/MathFunctions.h18
-rw-r--r--Eigen/src/Core/arch/CUDA/PacketMath.h6
-rw-r--r--Eigen/src/Core/functors/BinaryFunctors.h2
-rw-r--r--Eigen/src/Core/functors/TernaryFunctors.h47
-rw-r--r--Eigen/src/Core/util/ForwardDeclarations.h2
-rw-r--r--Eigen/src/Core/util/StaticAssert.h4
-rw-r--r--test/array.cpp113
-rw-r--r--unsupported/Eigen/CXX11/Tensor1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h81
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h96
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h33
-rw-r--r--unsupported/test/cxx11_tensor_cuda.cu150
21 files changed, 1389 insertions, 21 deletions
diff --git a/Eigen/Core b/Eigen/Core
index 25cb87930..1f3a6504e 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -371,6 +371,7 @@ using std::ptrdiff_t;
#include "src/Core/arch/Default/Settings.h"
+#include "src/Core/functors/TernaryFunctors.h"
#include "src/Core/functors/BinaryFunctors.h"
#include "src/Core/functors/UnaryFunctors.h"
#include "src/Core/functors/NullaryFunctors.h"
@@ -403,6 +404,7 @@ using std::ptrdiff_t;
#include "src/Core/PlainObjectBase.h"
#include "src/Core/Matrix.h"
#include "src/Core/Array.h"
+#include "src/Core/CwiseTernaryOp.h"
#include "src/Core/CwiseBinaryOp.h"
#include "src/Core/CwiseUnaryOp.h"
#include "src/Core/CwiseNullaryOp.h"
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h
index 5a5186a57..7ba92963c 100644
--- a/Eigen/src/Core/CoreEvaluators.h
+++ b/Eigen/src/Core/CoreEvaluators.h
@@ -41,11 +41,20 @@ template<> struct storage_kind_to_shape<TranspositionsStorage> { typedef Transp
// We currently distinguish the following kind of evaluators:
// - unary_evaluator for expressions taking only one arguments (CwiseUnaryOp, CwiseUnaryView, Transpose, MatrixWrapper, ArrayWrapper, Reverse, Replicate)
// - binary_evaluator for expression taking two arguments (CwiseBinaryOp)
+// - ternary_evaluator for expression taking three arguments (CwiseTernaryOp)
// - product_evaluator for linear algebra products (Product); special case of binary_evaluator because it requires additional tags for dispatching.
// - mapbase_evaluator for Map, Block, Ref
// - block_evaluator for Block (special dispatching to a mapbase_evaluator or unary_evaluator)
template< typename T,
+ typename Arg1Kind = typename evaluator_traits<typename T::Arg1>::Kind,
+ typename Arg2Kind = typename evaluator_traits<typename T::Arg2>::Kind,
+ typename Arg3Kind = typename evaluator_traits<typename T::Arg3>::Kind,
+ typename Arg1Scalar = typename traits<typename T::Arg1>::Scalar,
+ typename Arg2Scalar = typename traits<typename T::Arg2>::Scalar,
+ typename Arg3Scalar = typename traits<typename T::Arg3>::Scalar> struct ternary_evaluator;
+
+template< typename T,
typename LhsKind = typename evaluator_traits<typename T::Lhs>::Kind,
typename RhsKind = typename evaluator_traits<typename T::Rhs>::Kind,
typename LhsScalar = typename traits<typename T::Lhs>::Scalar,
@@ -442,6 +451,96 @@ protected:
evaluator<ArgType> m_argImpl;
};
+// -------------------- CwiseTernaryOp --------------------
+
+// this is a ternary expression
+template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
+struct evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >
+ : public ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >
+{
+ typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType;
+ typedef ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > Base;
+
+ EIGEN_DEVICE_FUNC explicit evaluator(const XprType& xpr) : Base(xpr) {}
+};
+
+template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
+struct ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>, IndexBased, IndexBased>
+ : evaluator_base<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >
+{
+ typedef CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> XprType;
+
+ enum {
+ CoeffReadCost = evaluator<Arg1>::CoeffReadCost + evaluator<Arg2>::CoeffReadCost + evaluator<Arg3>::CoeffReadCost + functor_traits<TernaryOp>::Cost,
+
+ Arg1Flags = evaluator<Arg1>::Flags,
+ Arg2Flags = evaluator<Arg2>::Flags,
+ Arg3Flags = evaluator<Arg3>::Flags,
+ SameType = is_same<typename Arg1::Scalar,typename Arg2::Scalar>::value && is_same<typename Arg1::Scalar,typename Arg3::Scalar>::value,
+ StorageOrdersAgree = (int(Arg1Flags)&RowMajorBit)==(int(Arg2Flags)&RowMajorBit) && (int(Arg1Flags)&RowMajorBit)==(int(Arg3Flags)&RowMajorBit),
+ Flags0 = (int(Arg1Flags) | int(Arg2Flags) | int(Arg3Flags)) & (
+ HereditaryBits
+ | (int(Arg1Flags) & int(Arg2Flags) & int(Arg3Flags) &
+ ( (StorageOrdersAgree ? LinearAccessBit : 0)
+ | (functor_traits<TernaryOp>::PacketAccess && StorageOrdersAgree && SameType ? PacketAccessBit : 0)
+ )
+ )
+ ),
+ Flags = (Flags0 & ~RowMajorBit) | (Arg1Flags & RowMajorBit),
+ Alignment = EIGEN_PLAIN_ENUM_MIN(
+ EIGEN_PLAIN_ENUM_MIN(evaluator<Arg1>::Alignment, evaluator<Arg2>::Alignment),
+ evaluator<Arg3>::Alignment)
+ };
+
+ EIGEN_DEVICE_FUNC explicit ternary_evaluator(const XprType& xpr)
+ : m_functor(xpr.functor()),
+ m_arg1Impl(xpr.arg1()),
+ m_arg2Impl(xpr.arg2()),
+ m_arg3Impl(xpr.arg3())
+ {
+ EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<TernaryOp>::Cost);
+ EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
+ }
+
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ CoeffReturnType coeff(Index row, Index col) const
+ {
+ return m_functor(m_arg1Impl.coeff(row, col), m_arg2Impl.coeff(row, col), m_arg3Impl.coeff(row, col));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ CoeffReturnType coeff(Index index) const
+ {
+ return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
+ }
+
+ template<int LoadMode, typename PacketType>
+ EIGEN_STRONG_INLINE
+ PacketType packet(Index row, Index col) const
+ {
+ return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(row, col),
+ m_arg2Impl.template packet<LoadMode,PacketType>(row, col),
+ m_arg3Impl.template packet<LoadMode,PacketType>(row, col));
+ }
+
+ template<int LoadMode, typename PacketType>
+ EIGEN_STRONG_INLINE
+ PacketType packet(Index index) const
+ {
+ return m_functor.packetOp(m_arg1Impl.template packet<LoadMode,PacketType>(index),
+ m_arg2Impl.template packet<LoadMode,PacketType>(index),
+ m_arg3Impl.template packet<LoadMode,PacketType>(index));
+ }
+
+protected:
+ const TernaryOp m_functor;
+ evaluator<Arg1> m_arg1Impl;
+ evaluator<Arg2> m_arg2Impl;
+ evaluator<Arg3> m_arg3Impl;
+};
+
// -------------------- CwiseBinaryOp --------------------
// this is a binary expression
diff --git a/Eigen/src/Core/CwiseTernaryOp.h b/Eigen/src/Core/CwiseTernaryOp.h
new file mode 100644
index 000000000..fe71c07cf
--- /dev/null
+++ b/Eigen/src/Core/CwiseTernaryOp.h
@@ -0,0 +1,212 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
+// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CWISE_TERNARY_OP_H
+#define EIGEN_CWISE_TERNARY_OP_H
+
+namespace Eigen {
+
+namespace internal {
+template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
+struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > {
+ // we must not inherit from traits<Arg1> since it has
+ // the potential to cause problems with MSVC
+ typedef typename remove_all<Arg1>::type Ancestor;
+ typedef typename traits<Ancestor>::XprKind XprKind;
+ enum {
+ RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
+ ColsAtCompileTime = traits<Ancestor>::ColsAtCompileTime,
+ MaxRowsAtCompileTime = traits<Ancestor>::MaxRowsAtCompileTime,
+ MaxColsAtCompileTime = traits<Ancestor>::MaxColsAtCompileTime
+ };
+
+ // even though we require Arg1, Arg2, and Arg3 to have the same scalar type
+ // (see CwiseTernaryOp constructor),
+ // we still want to handle the case when the result type is different.
+ typedef typename result_of<TernaryOp(
+ const typename Arg1::Scalar&, const typename Arg2::Scalar&,
+ const typename Arg3::Scalar&)>::type Scalar;
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename internal::traits<Arg1>::StorageKind,
+ typename internal::traits<Arg2>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename internal::traits<Arg1>::StorageKind,
+ typename internal::traits<Arg3>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename internal::traits<Arg1>::StorageIndex,
+ typename internal::traits<Arg3>::StorageIndex>::value),
+ STORAGE_INDEX_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename internal::traits<Arg1>::StorageIndex,
+ typename internal::traits<Arg3>::StorageIndex>::value),
+ STORAGE_INDEX_MUST_MATCH)
+
+ typedef typename internal::traits<Arg1>::StorageKind StorageKind;
+ typedef typename internal::traits<Arg1>::StorageIndex StorageIndex;
+
+ typedef typename Arg1::Nested Arg1Nested;
+ typedef typename Arg2::Nested Arg2Nested;
+ typedef typename Arg3::Nested Arg3Nested;
+ typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
+ typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
+ typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
+ enum { Flags = _Arg1Nested::Flags & RowMajorBit };
+};
+} // end namespace internal
+
+template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
+ typename StorageKind>
+class CwiseTernaryOpImpl;
+
+/** \class CwiseTernaryOp
+ * \ingroup Core_Module
+ *
+ * \brief Generic expression where a coefficient-wise ternary operator is
+ * applied to two expressions
+ *
+ * \tparam TernaryOp template functor implementing the operator
+ * \tparam Arg1Type the type of the first argument
+ * \tparam Arg2Type the type of the second argument
+ * \tparam Arg3Type the type of the third argument
+ *
+ * This class represents an expression where a coefficient-wise ternary
+ * operator is applied to three expressions.
+ * It is the return type of ternary operators, by which we mean only those
+ * ternary operators where
+ * all three arguments are Eigen expressions.
+ * For example, the return type of betainc(matrix1, matrix2, matrix3) is a
+ * CwiseTernaryOp.
+ *
+ * Most of the time, this is the only way that it is used, so you typically
+ * don't have to name
+ * CwiseTernaryOp types explicitly.
+ *
+ * \sa MatrixBase::ternaryExpr(const MatrixBase<Argument2> &, const
+ * MatrixBase<Argument3> &, const CustomTernaryOp &) const, class CwiseBinaryOp,
+ * class CwiseUnaryOp, class CwiseNullaryOp
+ */
+template <typename TernaryOp, typename Arg1Type, typename Arg2Type,
+ typename Arg3Type>
+class CwiseTernaryOp : public CwiseTernaryOpImpl<
+ TernaryOp, Arg1Type, Arg2Type, Arg3Type,
+ typename internal::traits<Arg1Type>::StorageKind>,
+ internal::no_assignment_operator {
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<
+ typename internal::traits<Arg1Type>::StorageKind,
+ typename internal::traits<Arg2Type>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<
+ typename internal::traits<Arg1Type>::StorageKind,
+ typename internal::traits<Arg3Type>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+
+ public:
+ typedef typename internal::remove_all<Arg1Type>::type Arg1;
+ typedef typename internal::remove_all<Arg2Type>::type Arg2;
+ typedef typename internal::remove_all<Arg3Type>::type Arg3;
+
+ typedef typename CwiseTernaryOpImpl<
+ TernaryOp, Arg1Type, Arg2Type, Arg3Type,
+ typename internal::traits<Arg1Type>::StorageKind>::Base Base;
+ EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseTernaryOp)
+
+ typedef typename internal::ref_selector<Arg1Type>::type Arg1Nested;
+ typedef typename internal::ref_selector<Arg2Type>::type Arg2Nested;
+ typedef typename internal::ref_selector<Arg3Type>::type Arg3Nested;
+ typedef typename internal::remove_reference<Arg1Nested>::type _Arg1Nested;
+ typedef typename internal::remove_reference<Arg2Nested>::type _Arg2Nested;
+ typedef typename internal::remove_reference<Arg3Nested>::type _Arg3Nested;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& a1, const Arg2& a2,
+ const Arg3& a3,
+ const TernaryOp& func = TernaryOp())
+ : m_arg1(a1), m_arg2(a2), m_arg3(a3), m_functor(func) {
+ // require the sizes to match
+ EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2)
+ EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg3)
+ eigen_assert(a1.rows() == a2.rows() && a1.cols() == a2.cols() &&
+ a1.rows() == a3.rows() && a1.cols() == a3.cols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Index rows() const {
+ // return the fixed size type if available to enable compile time
+ // optimizations
+ if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ RowsAtCompileTime == Dynamic &&
+ internal::traits<typename internal::remove_all<Arg2Nested>::type>::
+ RowsAtCompileTime == Dynamic)
+ return m_arg3.rows();
+ else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ RowsAtCompileTime == Dynamic &&
+ internal::traits<typename internal::remove_all<Arg3Nested>::type>::
+ RowsAtCompileTime == Dynamic)
+ return m_arg2.rows();
+ else
+ return m_arg1.rows();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Index cols() const {
+ // return the fixed size type if available to enable compile time
+ // optimizations
+ if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ ColsAtCompileTime == Dynamic &&
+ internal::traits<typename internal::remove_all<Arg2Nested>::type>::
+ ColsAtCompileTime == Dynamic)
+ return m_arg3.cols();
+ else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
+ ColsAtCompileTime == Dynamic &&
+ internal::traits<typename internal::remove_all<Arg3Nested>::type>::
+ ColsAtCompileTime == Dynamic)
+ return m_arg2.cols();
+ else
+ return m_arg1.cols();
+ }
+
+ /** \returns the first argument nested expression */
+ EIGEN_DEVICE_FUNC
+ const _Arg1Nested& arg1() const { return m_arg1; }
+ /** \returns the first argument nested expression */
+ EIGEN_DEVICE_FUNC
+ const _Arg2Nested& arg2() const { return m_arg2; }
+ /** \returns the third argument nested expression */
+ EIGEN_DEVICE_FUNC
+ const _Arg3Nested& arg3() const { return m_arg3; }
+ /** \returns the functor representing the ternary operation */
+ EIGEN_DEVICE_FUNC
+ const TernaryOp& functor() const { return m_functor; }
+
+ protected:
+ Arg1Nested m_arg1;
+ Arg2Nested m_arg2;
+ Arg3Nested m_arg3;
+ const TernaryOp m_functor;
+};
+
+// Generic API dispatcher
+template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
+ typename StorageKind>
+class CwiseTernaryOpImpl
+ : public internal::generic_xpr_base<
+ CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type {
+ public:
+ typedef typename internal::generic_xpr_base<
+ CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type Base;
+};
+
+} // end namespace Eigen
+
+#endif // EIGEN_CWISE_TERNARY_OP_H
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index 82fabbb70..76a75dee1 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -83,6 +83,7 @@ struct default_packet_traits
HasErfc = 0,
HasIGamma = 0,
HasIGammac = 0,
+ HasBetaInc = 0,
HasRound = 0,
HasFloor = 0,
@@ -466,6 +467,10 @@ Packet pigamma(const Packet& a, const Packet& x) { using numext::igamma; return
template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet pigammac(const Packet& a, const Packet& x) { using numext::igammac; return igammac(a, x); }
+/** \internal \returns the complementary incomplete gamma function betainc(\a a, \a b, \a x) */
+template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+Packet pbetainc(const Packet& a, const Packet& b,const Packet& x) { using numext::betainc; return betainc(a, b, x); }
+
/***************************************************************************
* The following functions might not have to be overwritten for vectorized types
***************************************************************************/
diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h
index 9c97ccb0e..e489cefec 100644
--- a/Eigen/src/Core/GlobalFunctions.h
+++ b/Eigen/src/Core/GlobalFunctions.h
@@ -213,6 +213,28 @@ namespace Eigen
);
}
+ /** \cpp11 \returns an expression of the coefficient-wise betainc(\a x, \a a, \a b) to the given arrays.
+ *
+ * This function computes the regularized incomplete beta function (integral).
+ *
+ * \note This function supports only float and double scalar types in c++11 mode. To support other scalar types,
+ * or float/double in non c++11 mode, the user has to provide implementations of betainc(T,T,T) for any scalar
+ * type T to be supported.
+ *
+ * \sa Eigen::betainc(), Eigen::lgamma()
+ */
+ template<typename ArgADerived, typename ArgBDerived, typename ArgXDerived>
+ inline const Eigen::CwiseTernaryOp<Eigen::internal::scalar_betainc_op<typename ArgXDerived::Scalar>, const ArgADerived, const ArgBDerived, const ArgXDerived>
+ betainc(const Eigen::ArrayBase<ArgADerived>& a, const Eigen::ArrayBase<ArgBDerived>& b, const Eigen::ArrayBase<ArgXDerived>& x)
+ {
+ return Eigen::CwiseTernaryOp<Eigen::internal::scalar_betainc_op<typename ArgXDerived::Scalar>, const ArgADerived, const ArgBDerived, const ArgXDerived>(
+ a.derived(),
+ b.derived(),
+ x.derived()
+ );
+ }
+
+
/** \returns an expression of the coefficient-wise zeta(\a x, \a q) to the given arrays.
*
* It returns the Riemann zeta function of two arguments \a x and \a q:
diff --git a/Eigen/src/Core/SpecialFunctions.h b/Eigen/src/Core/SpecialFunctions.h
index f34c7bcda..0dd9a1dc3 100644
--- a/Eigen/src/Core/SpecialFunctions.h
+++ b/Eigen/src/Core/SpecialFunctions.h
@@ -392,17 +392,19 @@ struct igammac_retval {
typedef Scalar type;
};
-// NOTE: igamma_helper is also used to implement zeta
+// NOTE: cephes_helper is also used to implement zeta
template <typename Scalar>
-struct igamma_helper {
+struct cephes_helper {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE Scalar biginv() { assert(false && "biginv not supported for this type"); return 0.0; }
};
template <>
-struct igamma_helper<float> {
+struct cephes_helper<float> {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE float machep() {
return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0
@@ -412,10 +414,15 @@ struct igamma_helper<float> {
// use epsneg (1.0 - epsneg == 1.0)
return 1.0f / (NumTraits<float>::epsilon() / 2);
}
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE float biginv() {
+ // epsneg
+ return machep();
+ }
};
template <>
-struct igamma_helper<double> {
+struct cephes_helper<double> {
EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE double machep() {
return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0
@@ -424,6 +431,11 @@ struct igamma_helper<double> {
static EIGEN_STRONG_INLINE double big() {
return 1.0 / NumTraits<double>::epsilon();
}
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE double biginv() {
+ // inverse of eps
+ return NumTraits<double>::epsilon();
+ }
};
#if !EIGEN_HAS_C99_MATH
@@ -538,10 +550,10 @@ struct igammac_impl {
const Scalar zero = 0;
const Scalar one = 1;
const Scalar two = 2;
- const Scalar machep = igamma_helper<Scalar>::machep();
+ const Scalar machep = cephes_helper<Scalar>::machep();
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
- const Scalar big = igamma_helper<Scalar>::big();
- const Scalar biginv = 1 / big;
+ const Scalar big = cephes_helper<Scalar>::big();
+ const Scalar biginv = cephes_helper<Scalar>::biginv();
const Scalar inf = NumTraits<Scalar>::infinity();
Scalar ans, ax, c, yc, r, t, y, z;
@@ -590,7 +602,9 @@ struct igammac_impl {
qkm2 *= biginv;
qkm1 *= biginv;
}
- if (t <= machep) break;
+ if (t <= machep) {
+ break;
+ }
}
return (ans * ax);
@@ -724,7 +738,7 @@ struct igamma_impl {
EIGEN_DEVICE_FUNC static Scalar Impl(Scalar a, Scalar x) {
const Scalar zero = 0;
const Scalar one = 1;
- const Scalar machep = igamma_helper<Scalar>::machep();
+ const Scalar machep = cephes_helper<Scalar>::machep();
const Scalar maxlog = numext::log(NumTraits<Scalar>::highest());
Scalar ans, ax, c, r;
@@ -746,7 +760,9 @@ struct igamma_impl {
r += one;
c *= x/r;
ans += c;
- if (c/ans <= machep) break;
+ if (c/ans <= machep) {
+ break;
+ }
}
return (ans * ax / a);
@@ -899,7 +915,7 @@ struct zeta_impl {
const Scalar maxnum = NumTraits<Scalar>::infinity();
const Scalar zero = 0.0, half = 0.5, one = 1.0;
- const Scalar machep = igamma_helper<Scalar>::machep();
+ const Scalar machep = cephes_helper<Scalar>::machep();
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
if( x == one )
@@ -947,8 +963,9 @@ struct zeta_impl {
t = a*b/A[i];
s = s + t;
t = numext::abs(t/s);
- if( t < machep )
- return s;
+ if( t < machep ) {
+ break;
+ }
k += one;
a *= x + k;
b /= w;
@@ -1007,6 +1024,467 @@ struct polygamma_impl {
#endif // EIGEN_HAS_C99_MATH
+/************************************************************************************************
+ * Implementation of betainc (incomplete beta integral), based on Cephes but requires C++11/C99 *
+ ************************************************************************************************/
+
+template <typename Scalar>
+struct betainc_retval {
+ typedef Scalar type;
+};
+
+#if !EIGEN_HAS_C99_MATH
+
+template <typename Scalar>
+struct betainc_impl {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
+ EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
+ THIS_TYPE_IS_NOT_SUPPORTED);
+ return Scalar(0);
+ }
+};
+
+#else
+
+template <typename Scalar>
+struct betainc_impl {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
+ /* betaincf.c
+ *
+ * Incomplete beta integral
+ *
+ *
+ * SYNOPSIS:
+ *
+ * float a, b, x, y, betaincf();
+ *
+ * y = betaincf( a, b, x );
+ *
+ *
+ * DESCRIPTION:
+ *
+ * Returns incomplete beta integral of the arguments, evaluated
+ * from zero to x. The function is defined as
+ *
+ * x
+ * - -
+ * | (a+b) | | a-1 b-1
+ * ----------- | t (1-t) dt.
+ * - - | |
+ * | (a) | (b) -
+ * 0
+ *
+ * The domain of definition is 0 <= x <= 1. In this
+ * implementation a and b are restricted to positive values.
+ * The integral from x to 1 may be obtained by the symmetry
+ * relation
+ *
+ * 1 - betainc( a, b, x ) = betainc( b, a, 1-x ).
+ *
+ * The integral is evaluated by a continued fraction expansion.
+ * If a < 1, the function calls itself recursively after a
+ * transformation to increase a to a+1.
+ *
+ * ACCURACY (float):
+ *
+ * Tested at random points (a,b,x) with a and b in the indicated
+ * interval and x between 0 and 1.
+ *
+ * arithmetic domain # trials peak rms
+ * Relative error:
+ * IEEE 0,30 10000 3.7e-5 5.1e-6
+ * IEEE 0,100 10000 1.7e-4 2.5e-5
+ * The useful domain for relative error is limited by underflow
+ * of the single precision exponential function.
+ * Absolute error:
+ * IEEE 0,30 100000 2.2e-5 9.6e-7
+ * IEEE 0,100 10000 6.5e-5 3.7e-6
+ *
+ * Larger errors may occur for extreme ratios of a and b.
+ *
+ * ACCURACY (double):
+ * arithmetic domain # trials peak rms
+ * IEEE 0,5 10000 6.9e-15 4.5e-16
+ * IEEE 0,85 250000 2.2e-13 1.7e-14
+ * IEEE 0,1000 30000 5.3e-12 6.3e-13
+ * IEEE 0,10000 250000 9.3e-11 7.1e-12
+ * IEEE 0,100000 10000 8.7e-10 4.8e-11
+ * Outputs smaller than the IEEE gradual underflow threshold
+ * were excluded from these statistics.
+ *
+ * ERROR MESSAGES:
+ * message condition value returned
+ * incbet domain x<0, x>1 nan
+ * incbet underflow nan
+ */
+
+ EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
+ THIS_TYPE_IS_NOT_SUPPORTED);
+ return Scalar(0);
+ }
+};
+
+/* Continued fraction expansion #1 for incomplete beta integral (small_branch = True)
+ * Continued fraction expansion #2 for incomplete beta integral (small_branch = False)
+ */
+template <typename Scalar>
+struct incbeta_cfe {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x, bool small_branch) {
+ EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value ||
+ internal::is_same<Scalar, double>::value),
+ THIS_TYPE_IS_NOT_SUPPORTED);
+ const Scalar big = cephes_helper<Scalar>::big();
+ const Scalar machep = cephes_helper<Scalar>::machep();
+ const Scalar biginv = cephes_helper<Scalar>::biginv();
+
+ const Scalar zero = 0;
+ const Scalar one = 1;
+ const Scalar two = 2;
+
+ Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
+ Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
+ Scalar ans;
+ int n;
+
+ const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
+ const Scalar thresh =
+ (internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
+ Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;
+
+ if (small_branch) {
+ k1 = a;
+ k2 = a + b;
+ k3 = a;
+ k4 = a + one;
+ k5 = one;
+ k6 = b - one;
+ k7 = k4;
+ k8 = a + two;
+ k26update = one;
+ } else {
+ k1 = a;
+ k2 = b - one;
+ k3 = a;
+ k4 = a + one;
+ k5 = one;
+ k6 = a + b;
+ k7 = a + one;
+ k8 = a + two;
+ k26update = -one;
+ x = x / (one - x);
+ }
+
+ pkm2 = zero;
+ qkm2 = one;
+ pkm1 = one;
+ qkm1 = one;
+ ans = one;
+ n = 0;
+
+ do {
+ xk = -(x * k1 * k2) / (k3 * k4);
+ pk = pkm1 + pkm2 * xk;
+ qk = qkm1 + qkm2 * xk;
+ pkm2 = pkm1;
+ pkm1 = pk;
+ qkm2 = qkm1;
+ qkm1 = qk;
+
+ xk = (x * k5 * k6) / (k7 * k8);
+ pk = pkm1 + pkm2 * xk;
+ qk = qkm1 + qkm2 * xk;
+ pkm2 = pkm1;
+ pkm1 = pk;
+ qkm2 = qkm1;
+ qkm1 = qk;
+
+ if (qk != zero) {
+ r = pk / qk;
+ if (numext::abs(ans - r) < numext::abs(r) * thresh) {
+ return r;
+ }
+ ans = r;
+ }
+
+ k1 += one;
+ k2 += k26update;
+ k3 += two;
+ k4 += two;
+ k5 += one;
+ k6 -= k26update;
+ k7 += two;
+ k8 += two;
+
+ if ((numext::abs(qk) + numext::abs(pk)) > big) {
+ pkm2 *= biginv;
+ pkm1 *= biginv;
+ qkm2 *= biginv;
+ qkm1 *= biginv;
+ }
+ if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
+ pkm2 *= big;
+ pkm1 *= big;
+ qkm2 *= big;
+ qkm1 *= big;
+ }
+ } while (++n < num_iters);
+
+ return ans;
+ }
+};
+
+/* Helper functions depending on the Scalar type */
+template <typename Scalar>
+struct betainc_helper {};
+
+template <>
+struct betainc_helper<float> {
+ /* Core implementation, assumes a large (> 1.0) */
+ EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float incbsa(float aa, float bb,
+ float xx) {
+ float ans, a, b, t, x, onemx;
+ bool reversed_a_b = false;
+
+ onemx = 1.0f - xx;
+
+ /* see if x is greater than the mean */
+ if (xx > (aa / (aa + bb))) {
+ reversed_a_b = true;
+ a = bb;
+ b = aa;
+ t = xx;
+ x = onemx;
+ } else {
+ a = aa;
+ b = bb;
+ t = onemx;
+ x = xx;
+ }
+
+ /* Choose expansion for optimal convergence */
+ if (b > 10.0f) {
+ if (numext::abs(b * x / a) < 0.3f) {
+ t = betainc_helper<float>::incbps(a, b, x);
+ if (reversed_a_b) t = 1.0f - t;
+ return t;
+ }
+ }
+
+ ans = x * (a + b - 2.0f) / (a - 1.0f);
+ if (ans < 1.0f) {
+ ans = incbeta_cfe<float>::run(a, b, x, true /* small_branch */);
+ t = b * numext::log(t);
+ } else {
+ ans = incbeta_cfe<float>::run(a, b, x, false /* small_branch */);
+ t = (b - 1.0f) * numext::log(t);
+ }
+
+ t += a * numext::log(x) + lgamma_impl<float>::run(a + b) -
+ lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
+ t += numext::log(ans / a);
+ t = numext::exp(t);
+
+ if (reversed_a_b) t = 1.0f - t;
+ return t;
+ }
+
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE float incbps(float a, float b, float x) {
+ float t, u, y, s;
+ const float machep = cephes_helper<float>::machep();
+
+ y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
+ y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
+ y += lgamma_impl<float>::run(a + b);
+
+ t = x / (1.0f - x);
+ s = 0.0f;
+ u = 1.0f;
+ do {
+ b -= 1.0f;
+ if (b == 0.0f) {
+ break;
+ }
+ a += 1.0f;
+ u *= t * b / a;
+ s += u;
+ } while (numext::abs(u) > machep);
+
+ return numext::exp(y) * (1.0f + s);
+ }
+};
+
+template <>
+struct betainc_impl<float> {
+ EIGEN_DEVICE_FUNC
+ static float run(float a, float b, float x) {
+ const float nan = NumTraits<float>::quiet_NaN();
+ float ans, t;
+
+ if (a <= 0.0f) return nan;
+ if (b <= 0.0f) return nan;
+ if ((x <= 0.0f) || (x >= 1.0f)) {
+ if (x == 0.0f) return 0.0f;
+ if (x == 1.0f) return 1.0f;
+ // mtherr("betaincf", DOMAIN);
+ return nan;
+ }
+
+ /* transformation for small aa */
+ if (a <= 1.0f) {
+ ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
+ t = a * numext::log(x) + b * numext::log1p(-x) +
+ lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a + 1.0f) -
+ lgamma_impl<float>::run(b);
+ return (ans + numext::exp(t));
+ } else {
+ return betainc_helper<float>::incbsa(a, b, x);
+ }
+ }
+};
+
+template <>
+struct betainc_helper<double> {
+ EIGEN_DEVICE_FUNC
+ static EIGEN_STRONG_INLINE double incbps(double a, double b, double x) {
+ const double machep = cephes_helper<double>::machep();
+
+ double s, t, u, v, n, t1, z, ai;
+
+ ai = 1.0 / a;
+ u = (1.0 - b) * x;
+ v = u / (a + 1.0);
+ t1 = v;
+ t = u;
+ n = 2.0;
+ s = 0.0;
+ z = machep * ai;
+ while (numext::abs(v) > z) {
+ u = (n - b) * x / n;
+ t *= u;
+ v = t / (a + n);
+ s += v;
+ n += 1.0;
+ }
+ s += t1;
+ s += ai;
+
+ u = a * numext::log(x);
+ // TODO: gamma() is not directly implemented in Eigen.
+ /*
+ if ((a + b) < maxgam && numext::abs(u) < maxlog) {
+ t = gamma(a + b) / (gamma(a) * gamma(b));
+ s = s * t * pow(x, a);
+ } else {
+ */
+ t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
+ lgamma_impl<double>::run(b) + u + numext::log(s);
+ return s = exp(t);
+ }
+};
+
+template <>
+struct betainc_impl<double> {
+ EIGEN_DEVICE_FUNC
+ static double run(double aa, double bb, double xx) {
+ const double nan = NumTraits<double>::quiet_NaN();
+ const double machep = cephes_helper<double>::machep();
+ // const double maxgam = 171.624376956302725;
+
+ double a, b, t, x, xc, w, y;
+ bool reversed_a_b = false;
+
+ if (aa <= 0.0 || bb <= 0.0) {
+ return nan; // goto domerr;
+ }
+
+ if ((xx <= 0.0) || (xx >= 1.0)) {
+ if (xx == 0.0) return (0.0);
+ if (xx == 1.0) return (1.0);
+ // mtherr("incbet", DOMAIN);
+ return nan;
+ }
+
+ if ((bb * xx) <= 1.0 && xx <= 0.95) {
+ return betainc_helper<double>::incbps(aa, bb, xx);
+ }
+
+ w = 1.0 - xx;
+
+ /* Reverse a and b if x is greater than the mean. */
+ if (xx > (aa / (aa + bb))) {
+ reversed_a_b = true;
+ a = bb;
+ b = aa;
+ xc = xx;
+ x = w;
+ } else {
+ a = aa;
+ b = bb;
+ xc = w;
+ x = xx;
+ }
+
+ if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
+ t = betainc_helper<double>::incbps(a, b, x);
+ if (t <= machep) {
+ t = 1.0 - machep;
+ } else {
+ t = 1.0 - t;
+ }
+ return t;
+ }
+
+ /* Choose expansion for better convergence. */
+ y = x * (a + b - 2.0) - (a - 1.0);
+ if (y < 0.0) {
+ w = incbeta_cfe<double>::run(a, b, x, true /* small_branch */);
+ } else {
+ w = incbeta_cfe<double>::run(a, b, x, false /* small_branch */) / xc;
+ }
+
+ /* Multiply w by the factor
+ a b _ _ _
+ x (1-x) | (a+b) / ( a | (a) | (b) ) . */
+
+ y = a * numext::log(x);
+ t = b * numext::log(xc);
+ // TODO: gamma is not directly implemented in Eigen.
+ /*
+ if ((a + b) < maxgam && numext::abs(y) < maxlog && numext::abs(t) < maxlog)
+ {
+ t = pow(xc, b);
+ t *= pow(x, a);
+ t /= a;
+ t *= w;
+ t *= gamma(a + b) / (gamma(a) * gamma(b));
+ } else {
+ */
+ /* Resort to logarithms. */
+ y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
+ lgamma_impl<double>::run(b);
+ y += numext::log(w / a);
+ t = numext::exp(y);
+
+ /* } */
+ // done:
+
+ if (reversed_a_b) {
+ if (t <= machep) {
+ t = 1.0 - machep;
+ } else {
+ t = 1.0 - t;
+ }
+ }
+ return t;
+ }
+};
+
+#endif // EIGEN_HAS_C99_MATH
+
} // end namespace internal
namespace numext {
@@ -1022,7 +1500,7 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(digamma, Scalar)
digamma(const Scalar& x) {
return EIGEN_MATHFUNC_IMPL(digamma, Scalar)::run(x);
}
-
+
template <typename Scalar>
EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(zeta, Scalar)
zeta(const Scalar& x, const Scalar& q) {
@@ -1059,6 +1537,12 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar)
return EIGEN_MATHFUNC_IMPL(igammac, Scalar)::run(a, x);
}
+template <typename Scalar>
+EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(betainc, Scalar)
+ betainc(const Scalar& a, const Scalar& b, const Scalar& x) {
+ return EIGEN_MATHFUNC_IMPL(betainc, Scalar)::run(a, b, x);
+}
+
} // end namespace numext
diff --git a/Eigen/src/Core/arch/CUDA/Half.h b/Eigen/src/Core/arch/CUDA/Half.h
index d4ce2eaf9..87bdbfd1e 100644
--- a/Eigen/src/Core/arch/CUDA/Half.h
+++ b/Eigen/src/Core/arch/CUDA/Half.h
@@ -480,6 +480,9 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igamma(const Eigen:
template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half igammac(const Eigen::half& a, const Eigen::half& x) {
return Eigen::half(Eigen::numext::igammac(static_cast<float>(a), static_cast<float>(x)));
}
+template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half betainc(const Eigen::half& a, const Eigen::half& b, const Eigen::half& x) {
+ return Eigen::half(Eigen::numext::betainc(static_cast<float>(a), static_cast<float>(b), static_cast<float>(x)));
+}
#endif
} // end namespace numext
diff --git a/Eigen/src/Core/arch/CUDA/MathFunctions.h b/Eigen/src/Core/arch/CUDA/MathFunctions.h
index c90ec96a0..8b5e8204f 100644
--- a/Eigen/src/Core/arch/CUDA/MathFunctions.h
+++ b/Eigen/src/Core/arch/CUDA/MathFunctions.h
@@ -181,6 +181,24 @@ double2 pigammac<double2>(const double2& a, const double2& x)
return make_double2(igammac(a.x, x.x), igammac(a.y, x.y));
}
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+float4 pbetainc<float4>(const float4& a, const float4& b, const float4& x)
+{
+ using numext::betainc;
+ return make_float4(
+ betainc(a.x, b.x, x.x),
+ betainc(a.y, b.y, x.y),
+ betainc(a.z, b.z, x.z),
+ betainc(a.w, b.w, x.w));
+}
+
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+double2 pbetainc<double2>(const double2& a, const double2& b, const double2& x)
+{
+ using numext::betainc;
+ return make_double2(betainc(a.x, b.x, x.x), betainc(a.y, b.y, x.y));
+}
+
#endif
} // end namespace internal
diff --git a/Eigen/src/Core/arch/CUDA/PacketMath.h b/Eigen/src/Core/arch/CUDA/PacketMath.h
index 25a59066c..ad66399e0 100644
--- a/Eigen/src/Core/arch/CUDA/PacketMath.h
+++ b/Eigen/src/Core/arch/CUDA/PacketMath.h
@@ -44,8 +44,9 @@ template<> struct packet_traits<float> : default_packet_traits
HasPolygamma = 1,
HasErf = 1,
HasErfc = 1,
- HasIgamma = 1,
+ HasIGamma = 1,
HasIGammac = 1,
+ HasBetaInc = 1,
HasBlend = 0,
};
@@ -68,10 +69,13 @@ template<> struct packet_traits<double> : default_packet_traits
HasRsqrt = 1,
HasLGamma = 1,
HasDiGamma = 1,
+ HasZeta = 1,
+ HasPolygamma = 1,
HasErf = 1,
HasErfc = 1,
HasIGamma = 1,
HasIGammac = 1,
+ HasBetaInc = 1,
HasBlend = 0,
};
diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h
index 5cd8ca950..6eb5b91ce 100644
--- a/Eigen/src/Core/functors/BinaryFunctors.h
+++ b/Eigen/src/Core/functors/BinaryFunctors.h
@@ -372,7 +372,7 @@ template<typename Scalar> struct scalar_igamma_op {
}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& x) const {
- return internal::pigammac(a, x);
+ return internal::pigamma(a, x);
}
};
template<typename Scalar>
diff --git a/Eigen/src/Core/functors/TernaryFunctors.h b/Eigen/src/Core/functors/TernaryFunctors.h
new file mode 100644
index 000000000..8b9e53062
--- /dev/null
+++ b/Eigen/src/Core/functors/TernaryFunctors.h
@@ -0,0 +1,47 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_TERNARY_FUNCTORS_H
+#define EIGEN_TERNARY_FUNCTORS_H
+
+namespace Eigen {
+
+namespace internal {
+
+//---------- associative ternary functors ----------
+
+/** \internal
+ * \brief Template functor to compute the incomplete beta integral betainc(a, b, x)
+ *
+ */
+template<typename Scalar> struct scalar_betainc_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_betainc_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& x, const Scalar& a, const Scalar& b) const {
+ using numext::betainc; return betainc(x, a, b);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& x, const Packet& a, const Packet& b) const
+ {
+ return internal::pbetainc(x, a, b);
+ }
+};
+template<typename Scalar>
+struct functor_traits<scalar_betainc_op<Scalar> > {
+ enum {
+ // Guesstimate
+ Cost = 400 * NumTraits<Scalar>::MulCost + 400 * NumTraits<Scalar>::AddCost,
+ PacketAccess = packet_traits<Scalar>::HasBetaInc
+ };
+};
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_TERNARY_FUNCTORS_H
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index a102e5457..42e2e75b9 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -91,6 +91,7 @@ template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
+template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp;
template<typename Decomposition, typename Rhstype> class Solve;
template<typename XprType> class Inverse;
@@ -208,6 +209,7 @@ template<typename Scalar> struct scalar_identity_op;
template<typename Scalar,bool iscpx> struct scalar_sign_op;
template<typename Scalar> struct scalar_igamma_op;
template<typename Scalar> struct scalar_igammac_op;
+template<typename Scalar> struct scalar_betainc_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
template<typename LhsScalar,typename RhsScalar> struct scalar_multiple2_op;
diff --git a/Eigen/src/Core/util/StaticAssert.h b/Eigen/src/Core/util/StaticAssert.h
index 6faaf889a..bee80ca38 100644
--- a/Eigen/src/Core/util/StaticAssert.h
+++ b/Eigen/src/Core/util/StaticAssert.h
@@ -98,7 +98,9 @@
EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT__INVALID_COST_VALUE,
THIS_COEFFICIENT_ACCESSOR_TAKING_ONE_ACCESS_IS_ONLY_FOR_EXPRESSIONS_ALLOWING_LINEAR_ACCESS,
MATRIX_FREE_CONJUGATE_GRADIENT_IS_COMPATIBLE_WITH_UPPER_UNION_LOWER_MODE_ONLY,
- THIS_TYPE_IS_NOT_SUPPORTED
+ THIS_TYPE_IS_NOT_SUPPORTED,
+ STORAGE_KIND_MUST_MATCH,
+ STORAGE_INDEX_MUST_MATCH
};
};
diff --git a/test/array.cpp b/test/array.cpp
index 39a7b856f..8ed1269c2 100644
--- a/test/array.cpp
+++ b/test/array.cpp
@@ -592,16 +592,123 @@ template<typename ArrayType> void array_special_functions()
ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
CALL_SUBTEST( verify_component_wise(ref, ref); );
- if(sizeof(RealScalar)>=64) {
-// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
+ if(sizeof(RealScalar)>=8) { // double
+ // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
+ // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); );
}
else {
-// CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
+ // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
}
}
#endif
+
+#if EIGEN_HAS_C99_MATH
+ {
+ // Inputs and ground truth generated with scipy via:
+ // a = np.logspace(-3, 3, 5) - 1e-3
+ // b = np.logspace(-3, 3, 5) - 1e-3
+ // x = np.linspace(-0.1, 1.1, 5)
+ // (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
+ // full_a = full_a.flatten().tolist() # same for full_b, full_x
+ // v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
+ //
+ // Note in Eigen, we call betainc with arguments in the order (x, a, b).
+ ArrayType a(125);
+ ArrayType b(125);
+ ArrayType x(125);
+ ArrayType v(125);
+ ArrayType res(125);
+
+ a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 999.999;
+
+ b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
+ 999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
+ 0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
+ 999.999, 999.999;
+
+ x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
+ 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
+ 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
+ 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
+ -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
+ 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
+ 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
+ 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
+ 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
+ 0.8, 1.1;
+
+ v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
+ nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
+ nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
+ 0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
+ 0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
+ 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
+ nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
+ 0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
+ 0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
+ 0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
+ 0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
+ 1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
+ nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
+ 0.0008598571564165444, nan, nan, 6.031987710123844e-08,
+ 0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
+ 0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
+ nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
+ 0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
+ 3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
+ 2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
+
+ CALL_SUBTEST(res = betainc(a, b, x);
+ verify_component_wise(res, v););
+ }
+#endif
}
void test_array()
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor
index 859147404..79bac2f67 100644
--- a/unsupported/Eigen/CXX11/Tensor
+++ b/unsupported/Eigen/CXX11/Tensor
@@ -80,6 +80,7 @@ typedef unsigned __int64 uint64_t;
#include "src/Tensor/TensorTraits.h"
#include "src/Tensor/TensorUInt128.h"
#include "src/Tensor/TensorIntDiv.h"
+#include "src/Tensor/TensorGlobalFunctions.h"
#include "src/Tensor/TensorBase.h"
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index eafd6f6f1..8f3580ba7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -307,7 +307,6 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_floor_op<Scalar>());
}
-
// Generic binary operation support.
template <typename CustomBinaryOp, typename OtherDerived> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<CustomBinaryOp, const Derived, const OtherDerived>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 31b361c83..4e873011e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -403,6 +403,87 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
TensorEvaluator<RightArgType, Device> m_rightImpl;
};
+// -------------------- CwiseTernaryOp --------------------
+
+template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
+struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;
+
+ enum {
+ IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
+ PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
+ internal::functor_traits<TernaryOp>::PacketAccess,
+ Layout = TensorEvaluator<Arg1Type, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ RawAccess = false
+ };
+
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
+ : m_functor(op.functor()),
+ m_arg1Impl(op.arg1Expression(), device),
+ m_arg2Impl(op.arg2Expression(), device),
+ m_arg3Impl(op.arg3Expression(), device)
+ {
+ EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
+ }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
+ typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
+ static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
+ typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;
+
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
+ {
+ // TODO: use arg2 or arg3 dimensions if they are known at compile time.
+ return m_arg1Impl.dimensions();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
+ m_arg1Impl.evalSubExprsIfNeeded(NULL);
+ m_arg2Impl.evalSubExprsIfNeeded(NULL);
+ m_arg3Impl.evalSubExprsIfNeeded(NULL);
+ return true;
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+ m_arg1Impl.cleanup();
+ m_arg2Impl.cleanup();
+ m_arg3Impl.cleanup();
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
+ }
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+ {
+ return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
+ m_arg2Impl.template packet<LoadMode>(index),
+ m_arg3Impl.template packet<LoadMode>(index));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
+ costPerCoeff(bool vectorized) const {
+ const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
+ return m_arg1Impl.costPerCoeff(vectorized) +
+ m_arg2Impl.costPerCoeff(vectorized) +
+ m_arg3Impl.costPerCoeff(vectorized) +
+ TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
+
+ private:
+ const TernaryOp m_functor;
+ TensorEvaluator<Arg1Type, Device> m_arg1Impl;
+ TensorEvaluator<Arg1Type, Device> m_arg2Impl;
+ TensorEvaluator<Arg3Type, Device> m_arg3Impl;
+};
+
// -------------------- SelectOp --------------------
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index ea250d8bc..9509f8002 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -219,6 +219,102 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
namespace internal {
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
+{
+ // Type promotion to handle the case where the types of the args are different.
+ typedef typename result_of<
+ TernaryOp(typename Arg1XprType::Scalar,
+ typename Arg2XprType::Scalar,
+ typename Arg3XprType::Scalar)>::type Scalar;
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::StorageKind,
+ typename traits<Arg2XprType>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::StorageKind,
+ typename traits<Arg3XprType>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::Index,
+ typename traits<Arg2XprType>::Index>::value),
+ STORAGE_INDEX_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::Index,
+ typename traits<Arg3XprType>::Index>::value),
+ STORAGE_INDEX_MUST_MATCH)
+ typedef traits<Arg1XprType> XprTraits;
+ typedef typename traits<Arg1XprType>::StorageKind StorageKind;
+ typedef typename traits<Arg1XprType>::Index Index;
+ typedef typename Arg1XprType::Nested Arg1Nested;
+ typedef typename Arg2XprType::Nested Arg2Nested;
+ typedef typename Arg3XprType::Nested Arg3Nested;
+ typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
+ typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
+ typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
+
+ enum {
+ Flags = 0
+ };
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
+{
+ typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
+};
+
+} // end namespace internal
+
+
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef Scalar CoeffReturnType;
+ typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
+ : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
+
+ EIGEN_DEVICE_FUNC
+ const TernaryOp& functor() const { return m_functor; }
+
+ /** \returns the nested expressions */
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg1Expression() const { return m_arg1_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg2Expression() const { return m_arg2_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg3XprType::Nested>::type&
+ arg3Expression() const { return m_arg3_xpr; }
+
+ protected:
+ typename Arg1XprType::Nested m_arg1_xpr;
+ typename Arg1XprType::Nested m_arg2_xpr;
+ typename Arg3XprType::Nested m_arg3_xpr;
+ const TernaryOp m_functor;
+};
+
+
+namespace internal {
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
: traits<ThenXprType>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
index a1a18d938..f35275ffb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -21,6 +21,7 @@ template<typename Derived, int AccessLevel = internal::accessors_level<Derived>:
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> class TensorCwiseTernaryOp;
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template<typename Op, typename Dims, typename XprType> class TensorReductionOp;
template<typename XprType> class TensorIndexTupleOp;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
new file mode 100644
index 000000000..665b861cf
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
@@ -0,0 +1,33 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
+#define EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
+
+namespace Eigen {
+
+/** \cpp11 \returns an expression of the coefficient-wise betainc(\a x, \a a, \a b) to the given tensors.
+ *
+ * This function computes the regularized incomplete beta function (integral).
+ *
+ */
+template <typename ADerived, typename BDerived, typename XDerived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const
+ TensorCwiseTernaryOp<internal::scalar_betainc_op<typename XDerived::Scalar>,
+ const ADerived, const BDerived, const XDerived>
+ betainc(const ADerived& a, const BDerived& b, const XDerived& x) {
+ return TensorCwiseTernaryOp<
+ internal::scalar_betainc_op<typename XDerived::Scalar>, const ADerived,
+ const BDerived, const XDerived>(
+ a, b, x, internal::scalar_betainc_op<typename XDerived::Scalar>());
+}
+
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
diff --git a/unsupported/test/cxx11_tensor_cuda.cu b/unsupported/test/cxx11_tensor_cuda.cu
index 4026f48f0..284b46803 100644
--- a/unsupported/test/cxx11_tensor_cuda.cu
+++ b/unsupported/test/cxx11_tensor_cuda.cu
@@ -1019,6 +1019,153 @@ void test_cuda_erfc(const Scalar stddev)
cudaFree(d_out);
}
+template <typename Scalar>
+void test_cuda_betainc()
+{
+ Tensor<Scalar, 1> in_x(125);
+ Tensor<Scalar, 1> in_a(125);
+ Tensor<Scalar, 1> in_b(125);
+ Tensor<Scalar, 1> out(125);
+ Tensor<Scalar, 1> expected_out(125);
+ out.setZero();
+
+ Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
+
+ Array<Scalar, 1, Dynamic> x(125);
+ Array<Scalar, 1, Dynamic> a(125);
+ Array<Scalar, 1, Dynamic> b(125);
+ Array<Scalar, 1, Dynamic> v(125);
+
+ a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
+ 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
+ 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999;
+
+ b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
+ 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
+ 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
+ 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
+ 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
+ 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 999.999,
+ 999.999, 999.999, 999.999;
+
+ x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
+ 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
+ 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
+ 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
+ 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
+ -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
+ 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
+ 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
+ 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1;
+
+ v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
+ nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
+ nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
+ 0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
+ 0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
+ 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan, nan,
+ nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
+ 0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
+ 0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
+ 0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
+ 0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
+ 1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06, nan,
+ nan, 7.864342668429763e-23, 3.015969667594166e-10, 0.0008598571564165444,
+ nan, nan, 6.031987710123844e-08, 0.5000000000000007, 0.9999999396801229,
+ nan, nan, 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan,
+ nan, nan, nan, nan, nan, nan, 0.0, 7.029920380986636e-306,
+ 2.2450728208591345e-101, nan, nan, 0.0, 9.275871147869727e-302,
+ 1.2232913026152827e-97, nan, nan, 0.0, 3.0891393081932924e-252,
+ 2.9303043666183996e-60, nan, nan, 2.248913486879199e-196,
+ 0.5000000000004947, 0.9999999999999999, nan;
+
+ for (int i = 0; i < 125; ++i) {
+ in_x(i) = x(i);
+ in_a(i) = a(i);
+ in_b(i) = b(i);
+ expected_out(i) = v(i);
+ }
+
+ std::size_t bytes = in_x.size() * sizeof(Scalar);
+
+ Scalar* d_in_x;
+ Scalar* d_in_a;
+ Scalar* d_in_b;
+ Scalar* d_out;
+ cudaMalloc((void**)(&d_in_x), bytes);
+ cudaMalloc((void**)(&d_in_a), bytes);
+ cudaMalloc((void**)(&d_in_b), bytes);
+ cudaMalloc((void**)(&d_out), bytes);
+
+ cudaMemcpy(d_in_x, in_x.data(), bytes, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_in_a, in_a.data(), bytes, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_in_b, in_b.data(), bytes, cudaMemcpyHostToDevice);
+
+ Eigen::CudaStreamDevice stream;
+ Eigen::GpuDevice gpu_device(&stream);
+
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_x(d_in_x, 125);
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_a(d_in_a, 125);
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_in_b(d_in_b, 125);
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 1> > gpu_out(d_out, 125);
+
+ gpu_out.device(gpu_device) = betainc(gpu_in_a, gpu_in_b, gpu_in_x);
+
+ assert(cudaMemcpyAsync(out.data(), d_out, bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
+ assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
+
+ for (int i = 1; i < 125; ++i) {
+ if ((std::isnan)(expected_out(i))) {
+ VERIFY((std::isnan)(out(i)));
+ } else {
+ VERIFY_IS_APPROX(out(i), expected_out(i));
+ }
+ }
+
+ cudaFree(d_in_x);
+ cudaFree(d_in_a);
+ cudaFree(d_in_b);
+ cudaFree(d_out);
+}
+
+
void test_cxx11_tensor_cuda()
{
CALL_SUBTEST_1(test_cuda_elementwise_small());
@@ -1086,5 +1233,8 @@ void test_cxx11_tensor_cuda()
CALL_SUBTEST_5(test_cuda_igamma<double>());
CALL_SUBTEST_5(test_cuda_igammac<double>());
+
+ CALL_SUBTEST_6(test_cuda_betainc<float>());
+ CALL_SUBTEST_6(test_cuda_betainc<double>());
#endif
}