aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-06-02 17:04:19 -0700
committerGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-06-02 17:04:19 -0700
commit39baff850c2f4fe1fee3b7a3918ba62a526e4f08 (patch)
tree841ea12578450cfc0ab3e96a68d4e433a985a01d /unsupported/Eigen
parent02db4e1a82e6059cc217d6aa57bcc5ac6342eb37 (diff)
Add TernaryFunctors and the betainc SpecialFunction.
TernaryFunctors and their executors allow operations on 3-tuples of inputs. API fully implemented for Arrays and Tensors based on binary functors. Ported the cephes betainc function (regularized incomplete beta integral) to Eigen, with support for CPU and GPU, floats, doubles, and half types. Added unit tests in array.cpp and cxx11_tensor_cuda.cu Collapsed revision * Merged helper methods for betainc across floats and doubles. * Added TensorGlobalFunctions with betainc(). Removed betainc() from TensorBase. * Clean up CwiseTernaryOp checks, change igamma_helper to cephes_helper. * betainc: merge incbcf and incbd into incbeta_cfe. and more cleanup. * Update TernaryOp and SpecialFunctions (betainc) based on review comments.
Diffstat (limited to 'unsupported/Eigen')
-rw-r--r--unsupported/Eigen/CXX11/Tensor1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h81
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h96
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h33
6 files changed, 212 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor
index 859147404..79bac2f67 100644
--- a/unsupported/Eigen/CXX11/Tensor
+++ b/unsupported/Eigen/CXX11/Tensor
@@ -80,6 +80,7 @@ typedef unsigned __int64 uint64_t;
#include "src/Tensor/TensorTraits.h"
#include "src/Tensor/TensorUInt128.h"
#include "src/Tensor/TensorIntDiv.h"
+#include "src/Tensor/TensorGlobalFunctions.h"
#include "src/Tensor/TensorBase.h"
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index eafd6f6f1..8f3580ba7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -307,7 +307,6 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_floor_op<Scalar>());
}
-
// Generic binary operation support.
template <typename CustomBinaryOp, typename OtherDerived> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<CustomBinaryOp, const Derived, const OtherDerived>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 31b361c83..4e873011e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -403,6 +403,87 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
TensorEvaluator<RightArgType, Device> m_rightImpl;
};
+// -------------------- CwiseTernaryOp --------------------
+
+template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
+struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;
+
+ enum {
+ IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
+ PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
+ internal::functor_traits<TernaryOp>::PacketAccess,
+ Layout = TensorEvaluator<Arg1Type, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ RawAccess = false
+ };
+
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
+ : m_functor(op.functor()),
+ m_arg1Impl(op.arg1Expression(), device),
+ m_arg2Impl(op.arg2Expression(), device),
+ m_arg3Impl(op.arg3Expression(), device)
+ {
+ EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
+ }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
+ typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
+ static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
+ typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;
+
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
+ {
+ // TODO: use arg2 or arg3 dimensions if they are known at compile time.
+ return m_arg1Impl.dimensions();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
+ m_arg1Impl.evalSubExprsIfNeeded(NULL);
+ m_arg2Impl.evalSubExprsIfNeeded(NULL);
+ m_arg3Impl.evalSubExprsIfNeeded(NULL);
+ return true;
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+ m_arg1Impl.cleanup();
+ m_arg2Impl.cleanup();
+ m_arg3Impl.cleanup();
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
+ }
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+ {
+ return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
+ m_arg2Impl.template packet<LoadMode>(index),
+ m_arg3Impl.template packet<LoadMode>(index));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
+ costPerCoeff(bool vectorized) const {
+ const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
+ return m_arg1Impl.costPerCoeff(vectorized) +
+ m_arg2Impl.costPerCoeff(vectorized) +
+ m_arg3Impl.costPerCoeff(vectorized) +
+ TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
+
+ private:
+ const TernaryOp m_functor;
+ TensorEvaluator<Arg1Type, Device> m_arg1Impl;
+ TensorEvaluator<Arg1Type, Device> m_arg2Impl;
+ TensorEvaluator<Arg3Type, Device> m_arg3Impl;
+};
+
// -------------------- SelectOp --------------------
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index ea250d8bc..9509f8002 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -219,6 +219,102 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
namespace internal {
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
+{
+ // Type promotion to handle the case where the types of the args are different.
+ typedef typename result_of<
+ TernaryOp(typename Arg1XprType::Scalar,
+ typename Arg2XprType::Scalar,
+ typename Arg3XprType::Scalar)>::type Scalar;
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::StorageKind,
+ typename traits<Arg2XprType>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::StorageKind,
+ typename traits<Arg3XprType>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::Index,
+ typename traits<Arg2XprType>::Index>::value),
+ STORAGE_INDEX_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::Index,
+ typename traits<Arg3XprType>::Index>::value),
+ STORAGE_INDEX_MUST_MATCH)
+ typedef traits<Arg1XprType> XprTraits;
+ typedef typename traits<Arg1XprType>::StorageKind StorageKind;
+ typedef typename traits<Arg1XprType>::Index Index;
+ typedef typename Arg1XprType::Nested Arg1Nested;
+ typedef typename Arg2XprType::Nested Arg2Nested;
+ typedef typename Arg3XprType::Nested Arg3Nested;
+ typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
+ typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
+ typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
+
+ enum {
+ Flags = 0
+ };
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
+{
+ typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
+};
+
+} // end namespace internal
+
+
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef Scalar CoeffReturnType;
+ typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
+ : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
+
+ EIGEN_DEVICE_FUNC
+ const TernaryOp& functor() const { return m_functor; }
+
+ /** \returns the nested expressions */
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg1Expression() const { return m_arg1_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg2Expression() const { return m_arg2_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg3XprType::Nested>::type&
+ arg3Expression() const { return m_arg3_xpr; }
+
+ protected:
+ typename Arg1XprType::Nested m_arg1_xpr;
+ typename Arg1XprType::Nested m_arg2_xpr;
+ typename Arg3XprType::Nested m_arg3_xpr;
+ const TernaryOp m_functor;
+};
+
+
+namespace internal {
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
: traits<ThenXprType>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
index a1a18d938..f35275ffb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -21,6 +21,7 @@ template<typename Derived, int AccessLevel = internal::accessors_level<Derived>:
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> class TensorCwiseTernaryOp;
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template<typename Op, typename Dims, typename XprType> class TensorReductionOp;
template<typename XprType> class TensorIndexTupleOp;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
new file mode 100644
index 000000000..665b861cf
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
@@ -0,0 +1,33 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
+#define EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
+
+namespace Eigen {
+
+/** \cpp11 \returns an expression of the coefficient-wise betainc(\a x, \a a, \a b) to the given tensors.
+ *
+ * This function computes the regularized incomplete beta function (integral).
+ *
+ */
+template <typename ADerived, typename BDerived, typename XDerived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const
+ TensorCwiseTernaryOp<internal::scalar_betainc_op<typename XDerived::Scalar>,
+ const ADerived, const BDerived, const XDerived>
+ betainc(const ADerived& a, const BDerived& b, const XDerived& x) {
+ return TensorCwiseTernaryOp<
+ internal::scalar_betainc_op<typename XDerived::Scalar>, const ADerived,
+ const BDerived, const XDerived>(
+ a, b, x, internal::scalar_betainc_op<typename XDerived::Scalar>());
+}
+
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H