aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen')
-rw-r--r--unsupported/Eigen/CXX11/Tensor1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h81
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h96
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h1
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h33
6 files changed, 212 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor
index 859147404..79bac2f67 100644
--- a/unsupported/Eigen/CXX11/Tensor
+++ b/unsupported/Eigen/CXX11/Tensor
@@ -80,6 +80,7 @@ typedef unsigned __int64 uint64_t;
#include "src/Tensor/TensorTraits.h"
#include "src/Tensor/TensorUInt128.h"
#include "src/Tensor/TensorIntDiv.h"
+#include "src/Tensor/TensorGlobalFunctions.h"
#include "src/Tensor/TensorBase.h"
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index eafd6f6f1..8f3580ba7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -307,7 +307,6 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_floor_op<Scalar>());
}
-
// Generic binary operation support.
template <typename CustomBinaryOp, typename OtherDerived> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<CustomBinaryOp, const Derived, const OtherDerived>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 31b361c83..4e873011e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -403,6 +403,87 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
TensorEvaluator<RightArgType, Device> m_rightImpl;
};
+// -------------------- CwiseTernaryOp --------------------
+
+template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
+struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;
+
+ enum {
+ IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
+ PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
+ internal::functor_traits<TernaryOp>::PacketAccess,
+ Layout = TensorEvaluator<Arg1Type, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ RawAccess = false
+ };
+
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
+ : m_functor(op.functor()),
+ m_arg1Impl(op.arg1Expression(), device),
+ m_arg2Impl(op.arg2Expression(), device),
+ m_arg3Impl(op.arg3Expression(), device)
+ {
+ EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
+ }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
+ typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
+ static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
+ typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;
+
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
+ {
+ // TODO: use arg2 or arg3 dimensions if they are known at compile time.
+ return m_arg1Impl.dimensions();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
+ m_arg1Impl.evalSubExprsIfNeeded(NULL);
+ m_arg2Impl.evalSubExprsIfNeeded(NULL);
+ m_arg3Impl.evalSubExprsIfNeeded(NULL);
+ return true;
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+ m_arg1Impl.cleanup();
+ m_arg2Impl.cleanup();
+ m_arg3Impl.cleanup();
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
+ }
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+ {
+ return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
+ m_arg2Impl.template packet<LoadMode>(index),
+ m_arg3Impl.template packet<LoadMode>(index));
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
+ costPerCoeff(bool vectorized) const {
+ const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
+ return m_arg1Impl.costPerCoeff(vectorized) +
+ m_arg2Impl.costPerCoeff(vectorized) +
+ m_arg3Impl.costPerCoeff(vectorized) +
+ TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
+ }
+
+ EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
+
+ private:
+ const TernaryOp m_functor;
+ TensorEvaluator<Arg1Type, Device> m_arg1Impl;
+ TensorEvaluator<Arg1Type, Device> m_arg2Impl;
+ TensorEvaluator<Arg3Type, Device> m_arg3Impl;
+};
+
// -------------------- SelectOp --------------------
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index ea250d8bc..9509f8002 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -219,6 +219,102 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
namespace internal {
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
+{
+ // Type promotion to handle the case where the types of the args are different.
+ typedef typename result_of<
+ TernaryOp(typename Arg1XprType::Scalar,
+ typename Arg2XprType::Scalar,
+ typename Arg3XprType::Scalar)>::type Scalar;
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::StorageKind,
+ typename traits<Arg2XprType>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::StorageKind,
+ typename traits<Arg3XprType>::StorageKind>::value),
+ STORAGE_KIND_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::Index,
+ typename traits<Arg2XprType>::Index>::value),
+ STORAGE_INDEX_MUST_MATCH)
+ EIGEN_STATIC_ASSERT(
+ (internal::is_same<typename traits<Arg1XprType>::Index,
+ typename traits<Arg3XprType>::Index>::value),
+ STORAGE_INDEX_MUST_MATCH)
+ typedef traits<Arg1XprType> XprTraits;
+ typedef typename traits<Arg1XprType>::StorageKind StorageKind;
+ typedef typename traits<Arg1XprType>::Index Index;
+ typedef typename Arg1XprType::Nested Arg1Nested;
+ typedef typename Arg2XprType::Nested Arg2Nested;
+ typedef typename Arg3XprType::Nested Arg3Nested;
+ typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
+ typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
+ typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
+
+ enum {
+ Flags = 0
+ };
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
+{
+ typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
+};
+
+} // end namespace internal
+
+
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef Scalar CoeffReturnType;
+ typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
+ : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
+
+ EIGEN_DEVICE_FUNC
+ const TernaryOp& functor() const { return m_functor; }
+
+ /** \returns the nested expressions */
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg1Expression() const { return m_arg1_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg2Expression() const { return m_arg2_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg3XprType::Nested>::type&
+ arg3Expression() const { return m_arg3_xpr; }
+
+ protected:
+ typename Arg1XprType::Nested m_arg1_xpr;
+ typename Arg1XprType::Nested m_arg2_xpr;
+ typename Arg3XprType::Nested m_arg3_xpr;
+ const TernaryOp m_functor;
+};
+
+
+namespace internal {
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
: traits<ThenXprType>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
index a1a18d938..f35275ffb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -21,6 +21,7 @@ template<typename Derived, int AccessLevel = internal::accessors_level<Derived>:
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> class TensorCwiseTernaryOp;
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template<typename Op, typename Dims, typename XprType> class TensorReductionOp;
template<typename XprType> class TensorIndexTupleOp;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
new file mode 100644
index 000000000..665b861cf
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h
@@ -0,0 +1,33 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
+#define EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H
+
+namespace Eigen {
+
+/** \cpp11 \returns an expression of the coefficient-wise betainc(\a x, \a a, \a b) to the given tensors.
+ *
+ * This function computes the regularized incomplete beta function (integral).
+ *
+ */
+template <typename ADerived, typename BDerived, typename XDerived>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const
+ TensorCwiseTernaryOp<internal::scalar_betainc_op<typename XDerived::Scalar>,
+ const ADerived, const BDerived, const XDerived>
+ betainc(const ADerived& a, const BDerived& b, const XDerived& x) {
+ return TensorCwiseTernaryOp<
+ internal::scalar_betainc_op<typename XDerived::Scalar>, const ADerived,
+ const BDerived, const XDerived>(
+ a, b, x, internal::scalar_betainc_op<typename XDerived::Scalar>());
+}
+
+} // end namespace Eigen
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_GLOBAL_FUNCTIONS_H