aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-06-02 17:04:19 -0700
committerGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-06-02 17:04:19 -0700
commit39baff850c2f4fe1fee3b7a3918ba62a526e4f08 (patch)
tree841ea12578450cfc0ab3e96a68d4e433a985a01d /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parent02db4e1a82e6059cc217d6aa57bcc5ac6342eb37 (diff)
Add TernaryFunctors and the betainc SpecialFunction.
TernaryFunctors and their executors allow operations on 3-tuples of inputs. API fully implemented for Arrays and Tensors based on binary functors. Ported the cephes betainc function (regularized incomplete beta integral) to Eigen, with support for CPU and GPU, floats, doubles, and half types. Added unit tests in array.cpp and cxx11_tensor_cuda.cu Collapsed revision * Merged helper methods for betainc across floats and doubles. * Added TensorGlobalFunctions with betainc(). Removed betainc() from TensorBase. * Clean up CwiseTernaryOp checks, change igamma_helper to cephes_helper. * betainc: merge incbcf and incbd into incbeta_cfe. and more cleanup. * Update TernaryOp and SpecialFunctions (betainc) based on review comments.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h81
1 files changed, 81 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 31b361c83..4e873011e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -403,6 +403,87 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
TensorEvaluator<RightArgType, Device> m_rightImpl;
};
+// -------------------- CwiseTernaryOp --------------------
+
+template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
+struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;
+
+ enum {
+ IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
+ PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
+ internal::functor_traits<TernaryOp>::PacketAccess,
+ Layout = TensorEvaluator<Arg1Type, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ RawAccess = false
+ };
+
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
+ : m_functor(op.functor()),
+ m_arg1Impl(op.arg1Expression(), device),
+ m_arg2Impl(op.arg2Expression(), device),
+ m_arg3Impl(op.arg3Expression(), device)
+ {
+ EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
+ }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
+ typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
+ static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
+ typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;
+
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
+ {
+ // TODO: use arg2 or arg3 dimensions if they are known at compile time.
+ return m_arg1Impl.dimensions();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
+ m_arg1Impl.evalSubExprsIfNeeded(NULL);
+ m_arg2Impl.evalSubExprsIfNeeded(NULL);
+ m_arg3Impl.evalSubExprsIfNeeded(NULL);
+ return true;
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+ m_arg1Impl.cleanup();
+ m_arg2Impl.cleanup();
+ m_arg3Impl.cleanup();
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
+ }
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+ {
+ return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
+ m_arg2Impl.template packet<LoadMode>(index),
+ m_arg3Impl.template packet<LoadMode>(index));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
+ costPerCoeff(bool vectorized) const {
+ const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
+ return m_arg1Impl.costPerCoeff(vectorized) +
+ m_arg2Impl.costPerCoeff(vectorized) +
+ m_arg3Impl.costPerCoeff(vectorized) +
+ TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
+
+ private:
+ const TernaryOp m_functor;
+ TensorEvaluator<Arg1Type, Device> m_arg1Impl;
+ TensorEvaluator<Arg1Type, Device> m_arg2Impl;
+ TensorEvaluator<Arg3Type, Device> m_arg3Impl;
+};
+
// -------------------- SelectOp --------------------