diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 82 |
1 files changed, 81 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 8491c4ca2..5f2e329f2 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -219,6 +219,86 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX namespace internal { +template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> +struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> > +{ + // Type promotion to handle the case where the types of the args are different. + typedef typename result_of< + TernaryOp(typename Arg1XprType::Scalar, + typename Arg2XprType::Scalar, + typename Arg3XprType::Scalar)>::type Scalar; + typedef traits<Arg1XprType> XprTraits; + typedef typename traits<Arg1XprType>::StorageKind StorageKind; + typedef typename traits<Arg1XprType>::Index Index; + typedef typename Arg1XprType::Nested Arg1Nested; + typedef typename Arg2XprType::Nested Arg2Nested; + typedef typename Arg3XprType::Nested Arg3Nested; + typedef typename remove_reference<Arg1Nested>::type _Arg1Nested; + typedef typename remove_reference<Arg2Nested>::type _Arg2Nested; + typedef typename remove_reference<Arg3Nested>::type _Arg3Nested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; + + enum { + Flags = 0 + }; +}; + +template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> +struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense> +{ + typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type; +}; + +template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> +struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type> +{ + typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type; +}; + +} // end namespace internal + + + +template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> +class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors> +{ + public: + typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef Scalar CoeffReturnType; + typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested; + typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp()) + : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {} + + EIGEN_DEVICE_FUNC + const TernaryOp& functor() const { return m_functor; } + + /** \returns the nested expressions */ + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename Arg1XprType::Nested>::type& + arg1Expression() const { return m_arg1_xpr; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename Arg1XprType::Nested>::type& + arg2Expression() const { return m_arg2_xpr; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename Arg3XprType::Nested>::type& + arg3Expression() const { return m_arg3_xpr; } + + protected: + typename Arg1XprType::Nested m_arg1_xpr; + typename Arg1XprType::Nested m_arg2_xpr; + typename Arg3XprType::Nested m_arg3_xpr; + const TernaryOp m_functor; +}; + + +namespace internal { template<typename IfXprType, typename ThenXprType, typename ElseXprType> struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > : traits<ThenXprType> @@ -252,7 +332,7 @@ struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename e template<typename IfXprType, typename ThenXprType, typename ElseXprType> -class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > +class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar; |