aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-22 16:22:35 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-22 16:22:35 -0700
commit736267cf6b17832a571acf7e34ca07c7f55907ee (patch)
tree894d0bfd7455b670117a252afad0157ba01a766b /unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
parent7402fea0a8e63e3ea248257047c584afee8f8bde (diff)
Added support for additional tensor operations:
* comparison (<, <=, ==, !=, ...) * selection * nullary ops such as random or constant generation * misc unary ops such as log(), exp(), or a user defined unaryExpr() Cleaned up the code a little.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h109
1 files changed, 109 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index e32077f6e..94cfae05c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -17,6 +17,9 @@ namespace Eigen {
*
* \brief Tensor expression classes.
*
+ * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This
+ * is typically used to generate constants.
+ *
* The TensorCwiseUnaryOp class represents an expression where a unary operator
* (e.g. cwiseSqrt) is applied to an expression.
*
@@ -24,6 +27,46 @@ namespace Eigen {
* (e.g. addition) is applied to a lhs and a rhs expression.
*
*/
+namespace internal {
+template<typename NullaryOp, typename PlainObjectType>
+struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+ : traits<PlainObjectType>
+{
+ typedef typename PlainObjectType::Packet Packet;
+ typedef typename PlainObjectType::Scalar Scalar;
+ typedef typename PlainObjectType::Nested XprTypeNested;
+ typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
+};
+
+} // end namespace internal
+
+
+
+template<typename NullaryOp, typename PlainObjectType>
+class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
+ typedef typename PlainObjectType::PacketReturnType PacketReturnType;
+ typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested;
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp())
+ : m_functor(func) {}
+
+ EIGEN_DEVICE_FUNC
+ const NullaryOp& functor() const { return m_functor; }
+
+ protected:
+ // todo: add tensor dimension to be able to do some sanity checks
+ const NullaryOp m_functor;
+};
+
+
namespace internal {
template<typename UnaryOp, typename XprType>
@@ -160,6 +203,72 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
const BinaryOp m_functor;
};
+
+namespace internal {
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
+ : traits<ThenXprType>
+{
+ typedef typename traits<ThenXprType>::Scalar Scalar;
+ typedef typename internal::packet_traits<Scalar>::type Packet;
+ typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
+ typename traits<ElseXprType>::StorageKind>::ret StorageKind;
+ typedef typename promote_index_type<typename traits<ElseXprType>::Index,
+ typename traits<ThenXprType>::Index>::type Index;
+ typedef typename IfXprType::Nested IfNested;
+ typedef typename ThenXprType::Nested ThenNested;
+ typedef typename ElseXprType::Nested ElseNested;
+};
+
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
+{
+ typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
+};
+
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
+{
+ typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
+};
+
+} // end namespace internal
+
+
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
+ typedef typename Eigen::internal::traits<TensorSelectOp>::Packet Packet;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
+ typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
+ typedef typename internal::promote_storage_type<typename ThenXprType::PacketReturnType,
+ typename ElseXprType::PacketReturnType>::ret PacketReturnType;
+ typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
+
+ TensorSelectOp(const IfXprType& a_condition,
+ const ThenXprType& a_then,
+ const ElseXprType& a_else)
+ : m_condition(a_condition), m_then(a_then), m_else(a_else)
+ { }
+
+ const IfXprType& ifExpression() const { return m_condition; }
+
+ const ThenXprType& thenExpression() const { return m_then; }
+
+ const ElseXprType& elseExpression() const { return m_else; }
+
+ protected:
+ typename IfXprType::Nested m_condition;
+ typename ThenXprType::Nested m_then;
+ typename ElseXprType::Nested m_else;
+};
+
+
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H