aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h82
1 files changed, 81 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index 8491c4ca2..5f2e329f2 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -219,6 +219,86 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
namespace internal {
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
+{
+ // Type promotion to handle the case where the types of the args are different.
+ typedef typename result_of<
+ TernaryOp(typename Arg1XprType::Scalar,
+ typename Arg2XprType::Scalar,
+ typename Arg3XprType::Scalar)>::type Scalar;
+ typedef traits<Arg1XprType> XprTraits;
+ typedef typename traits<Arg1XprType>::StorageKind StorageKind;
+ typedef typename traits<Arg1XprType>::Index Index;
+ typedef typename Arg1XprType::Nested Arg1Nested;
+ typedef typename Arg2XprType::Nested Arg2Nested;
+ typedef typename Arg3XprType::Nested Arg3Nested;
+ typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
+ typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
+ typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
+
+ enum {
+ Flags = 0
+ };
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
+{
+ typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
+};
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
+{
+ typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
+};
+
+} // end namespace internal
+
+
+
+template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
+class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef Scalar CoeffReturnType;
+ typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
+ : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
+
+ EIGEN_DEVICE_FUNC
+ const TernaryOp& functor() const { return m_functor; }
+
+ /** \returns the nested expressions */
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg1Expression() const { return m_arg1_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg1XprType::Nested>::type&
+ arg2Expression() const { return m_arg2_xpr; }
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename Arg3XprType::Nested>::type&
+ arg3Expression() const { return m_arg3_xpr; }
+
+ protected:
+ typename Arg1XprType::Nested m_arg1_xpr;
+ typename Arg1XprType::Nested m_arg2_xpr;
+ typename Arg3XprType::Nested m_arg3_xpr;
+ const TernaryOp m_functor;
+};
+
+
+namespace internal {
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
: traits<ThenXprType>
@@ -252,7 +332,7 @@ struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename e
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
-class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
+class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
{
public:
typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;