aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-22 16:22:35 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-22 16:22:35 -0700
commit736267cf6b17832a571acf7e34ca07c7f55907ee (patch)
tree894d0bfd7455b670117a252afad0157ba01a766b /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parent7402fea0a8e63e3ea248257047c584afee8f8bde (diff)
Added support for additional tensor operations:
* comparison (<, <=, ==, !=, ...) * selection * nullary ops such as random or constant generation * misc unary ops such as log(), exp(), or a user defined unaryExpr() Cleaned up the code a little.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h84
1 files changed, 84 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 3ce924dc3..e0c0863b7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -68,6 +68,42 @@ struct TensorEvaluator
+// -------------------- CwiseNullaryOp --------------------
+
+template<typename NullaryOp, typename PlainObjectType>
+struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+{
+ typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType;
+
+ enum {
+ IsAligned = true,
+ PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
+ };
+
+ TensorEvaluator(const XprType& op)
+ : m_functor(op.functor())
+ { }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_functor(index);
+ }
+
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
+ {
+ return m_functor.packetOp(index);
+ }
+
+ private:
+ const NullaryOp m_functor;
+};
+
+
// -------------------- CwiseUnaryOp --------------------
@@ -146,6 +182,54 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
TensorEvaluator<RightArgType> m_rightImpl;
};
+
+// -------------------- SelectOp --------------------
+
+template<typename IfArgType, typename ThenArgType, typename ElseArgType>
+struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> >
+{
+ typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
+
+ enum {
+ IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned,
+ PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* &
+ TensorEvaluator<IfArgType>::PacketAccess*/,
+ };
+
+ TensorEvaluator(const XprType& op)
+ : m_condImpl(op.ifExpression()),
+ m_thenImpl(op.thenExpression()),
+ m_elseImpl(op.elseExpression())
+ { }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
+ }
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
+ {
+ static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
+ internal::Selector<PacketSize> select;
+ for (Index i = 0; i < PacketSize; ++i) {
+ select.select[i] = m_condImpl.coeff(index+i);
+ }
+ return internal::pblend(select,
+ m_thenImpl.template packet<LoadMode>(index),
+ m_elseImpl.template packet<LoadMode>(index));
+ }
+
+ private:
+ TensorEvaluator<IfArgType> m_condImpl;
+ TensorEvaluator<ThenArgType> m_thenImpl;
+ TensorEvaluator<ElseArgType> m_elseImpl;
+};
+
+
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H