aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 12:47:46 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 12:47:46 -0800
commit1ac86001266db55b78086617fb68206b29748919 (patch)
tree77ef4a659d2743390cbed843726cb6a47b1229c0 /unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
parent378bdfb7f0c4b2a8eb2b91c2a65f3bc1c57e689e (diff)
Fixed the return type of coefficient wise operations. For example, the abs function returns a floating point value when called on a complex input.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h87
1 files changed, 52 insertions, 35 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index 6e5503de1..b66b3ec2c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -17,14 +17,14 @@ namespace Eigen {
*
* \brief Tensor expression classes.
*
- * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This
- * is typically used to generate constants.
+ * The TensorCwiseNullaryOp class applies a nullary operators to an expression.
+ * This is typically used to generate constants.
*
* The TensorCwiseUnaryOp class represents an expression where a unary operator
* (e.g. cwiseSqrt) is applied to an expression.
*
- * The TensorCwiseBinaryOp class represents an expression where a binary operator
- * (e.g. addition) is applied to a lhs and a rhs expression.
+ * The TensorCwiseBinaryOp class represents an expression where a binary
+ * operator (e.g. addition) is applied to a lhs and a rhs expression.
*
*/
namespace internal {
@@ -33,9 +33,12 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
: traits<XprType>
{
typedef typename XprType::Packet Packet;
+ typedef traits<XprType> XprTraits;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
enum {
Flags = 0,
@@ -47,7 +50,7 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
template<typename NullaryOp, typename XprType>
-class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType> >
+class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
{
public:
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
@@ -81,12 +84,15 @@ template<typename UnaryOp, typename XprType>
struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
: traits<XprType>
{
- typedef typename result_of<
- UnaryOp(typename XprType::Scalar)
- >::type Scalar;
+ // TODO(phli): Add InputScalar, InputPacket. Check references to
+ // current Scalar/Packet to see if the intent is Input or Output.
+ typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
+ typedef traits<XprType> XprTraits;
typedef typename internal::packet_traits<Scalar>::type Packet;
typedef typename XprType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
};
template<typename UnaryOp, typename XprType>
@@ -106,14 +112,16 @@ struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwise
template<typename UnaryOp, typename XprType>
-class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType> >
+class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
{
public:
+ // TODO(phli): Add InputScalar, InputPacket. Check references to
+ // current Scalar/Packet to see if the intent is Input or Output.
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Packet Packet;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
- typedef typename XprType::CoeffReturnType CoeffReturnType;
- typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef Scalar CoeffReturnType;
+ typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
@@ -139,22 +147,27 @@ namespace internal {
template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
{
- // Type promotion to handle the case where the types of the lhs and the rhs are different.
+ // Type promotion to handle the case where the types of the lhs and the rhs
+ // are different.
+ // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
+ // current Scalar/Packet to see if the intent is Inputs or Output.
typedef typename result_of<
- BinaryOp(
- typename LhsXprType::Scalar,
- typename RhsXprType::Scalar
- )
- >::type Scalar;
+ BinaryOp(typename LhsXprType::Scalar,
+ typename RhsXprType::Scalar)>::type Scalar;
+ typedef traits<LhsXprType> XprTraits;
typedef typename internal::packet_traits<Scalar>::type Packet;
- typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
- typename traits<RhsXprType>::StorageKind>::ret StorageKind;
- typedef typename promote_index_type<typename traits<LhsXprType>::Index,
- typename traits<RhsXprType>::Index>::type Index;
+ typedef typename promote_storage_type<
+ typename traits<LhsXprType>::StorageKind,
+ typename traits<RhsXprType>::StorageKind>::ret StorageKind;
+ typedef typename promote_index_type<
+ typename traits<LhsXprType>::Index,
+ typename traits<RhsXprType>::Index>::type Index;
typedef typename LhsXprType::Nested LhsNested;
typedef typename RhsXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
enum {
Flags = 0,
@@ -178,21 +191,22 @@ struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename
template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
-class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
+class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
public:
- typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
- typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet;
- typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
- typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
- typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
- typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
- typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
- typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
- typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
- : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
+ // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
+ // current Scalar/Packet to see if the intent is Inputs or Output.
+ typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
+ typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef Scalar CoeffReturnType;
+ typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType;
+ typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
+ : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
EIGEN_DEVICE_FUNC
const BinaryOp& functor() const { return m_functor; }
@@ -219,7 +233,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
: traits<ThenXprType>
{
typedef typename traits<ThenXprType>::Scalar Scalar;
- typedef typename internal::packet_traits<Scalar>::type Packet;
+ typedef traits<ThenXprType> XprTraits;
+ typedef typename packet_traits<Scalar>::type Packet;
typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
typename traits<ElseXprType>::StorageKind>::ret StorageKind;
typedef typename promote_index_type<typename traits<ElseXprType>::Index,
@@ -227,6 +242,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
typedef typename IfXprType::Nested IfNested;
typedef typename ThenXprType::Nested ThenNested;
typedef typename ElseXprType::Nested ElseNested;
+ static const int NumDimensions = XprTraits::NumDimensions;
+ static const int Layout = XprTraits::Layout;
};
template<typename IfXprType, typename ThenXprType, typename ElseXprType>