diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-05-22 16:22:35 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-05-22 16:22:35 -0700 |
commit | 736267cf6b17832a571acf7e34ca07c7f55907ee (patch) | |
tree | 894d0bfd7455b670117a252afad0157ba01a766b /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | |
parent | 7402fea0a8e63e3ea248257047c584afee8f8bde (diff) |
Added support for additional tensor operations:
* comparison (<, <=, ==, !=, ...)
* selection
* nullary ops such as random or constant generation
* misc unary ops such as log(), exp(), or a user defined unaryExpr()
Cleaned up the code a little.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index 3ce924dc3..e0c0863b7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -68,6 +68,42 @@ struct TensorEvaluator +// -------------------- CwiseNullaryOp -------------------- + +template<typename NullaryOp, typename PlainObjectType> +struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > +{ + typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType; + + enum { + IsAligned = true, + PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess, + }; + + TensorEvaluator(const XprType& op) + : m_functor(op.functor()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const + { + return m_functor(index); + } + + template<int LoadMode> + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const + { + return m_functor.packetOp(index); + } + + private: + const NullaryOp m_functor; +}; + + // -------------------- CwiseUnaryOp -------------------- @@ -146,6 +182,54 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg TensorEvaluator<RightArgType> m_rightImpl; }; + +// -------------------- SelectOp -------------------- + +template<typename IfArgType, typename ThenArgType, typename ElseArgType> +struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> > +{ + typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType; + + enum { + IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned, + PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* & + TensorEvaluator<IfArgType>::PacketAccess*/, + }; + + TensorEvaluator(const XprType& op) + : m_condImpl(op.ifExpression()), + m_thenImpl(op.thenExpression()), + m_elseImpl(op.elseExpression()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const + { + return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index); + } + template<int LoadMode> + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const + { + static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; + internal::Selector<PacketSize> select; + for (Index i = 0; i < PacketSize; ++i) { + select.select[i] = m_condImpl.coeff(index+i); + } + return internal::pblend(select, + m_thenImpl.template packet<LoadMode>(index), + m_elseImpl.template packet<LoadMode>(index)); + } + + private: + TensorEvaluator<IfArgType> m_condImpl; + TensorEvaluator<ThenArgType> m_thenImpl; + TensorEvaluator<ElseArgType> m_elseImpl; +}; + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H |