diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-05-22 16:22:35 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-05-22 16:22:35 -0700 |
commit | 736267cf6b17832a571acf7e34ca07c7f55907ee (patch) | |
tree | 894d0bfd7455b670117a252afad0157ba01a766b | |
parent | 7402fea0a8e63e3ea248257047c584afee8f8bde (diff) |
Added support for additional tensor operations:
* comparison (<, <=, ==, !=, ...)
* selection
* nullary ops such as random or constant generation
* misc unary ops such as log(), exp(), or a user defined unaryExpr()
Cleaned up the code a little.
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 139 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 84 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 109 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h | 2 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMap.h | 36 |
5 files changed, 339 insertions, 31 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index fa1bd3498..8a88ba806 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -33,21 +33,25 @@ class TensorBase Derived& setZero() { return setConstant(Scalar(0)); } - Derived& setConstant(const Scalar& val) { - Scalar* data = derived().data(); - for (int i = 0; i < derived().size(); ++i) { - data[i] = val; - } - return derived(); + return derived() = constant(val); } - Derived& setRandom() { - Scalar* data = derived().data(); - for (int i = 0; i < derived().size(); ++i) { - data[i] = internal::random_default_impl<Scalar, false, false>::run(); - } - return derived(); + return derived() = random(); + } + + // Nullary operators + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> + constant(const Scalar& value) const { + return TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> + (internal::scalar_constant_op<Scalar>(value)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived> + random() const { + return TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>(); } // Coefficient-wise unary operators @@ -57,15 +61,31 @@ class TensorBase EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> - cwiseSqrt() const { return derived(); } + sqrt() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived> + square() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> + inverse() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_exp_op<Scalar>, const Derived> + exp() const { return derived(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived> + log() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived> - cwiseAbs() const { return derived(); } + abs() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> - cwisePow(Scalar exponent) const { + pow(Scalar exponent) const { return TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> (derived(), internal::scalar_pow_op<Scalar>(exponent)); } @@ -77,6 +97,30 @@ class TensorBase (derived(), internal::scalar_multiple_op<Scalar>(scale)); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > + cwiseMax(Scalar threshold) const { + return cwiseMax(constant(threshold)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > + cwiseMin(Scalar threshold) const { + return cwiseMin(constant(threshold)); + } + + template <typename CustomUnaryOp> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<CustomUnaryOp, const Derived> + unaryExpr(const CustomUnaryOp& func) const { + return TensorCwiseUnaryOp<CustomUnaryOp, const Derived>(derived(), func); + } + + template <typename NewType> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_cast_op<Scalar, NewType>, const Derived> + cast() const { + return derived(); + } + // Coefficient-wise binary operators. template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived> @@ -90,6 +134,71 @@ class TensorBase return TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); } + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived> + operator*(const OtherDerived& other) const { + return TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived> + operator/(const OtherDerived& other) const { + return TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived> + cwiseMax(const OtherDerived& other) const { + return TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived> + cwiseMin(const OtherDerived& other) const { + return TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + + // Comparisons and tests. + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived> + operator<(const OtherDerived& other) const { + return TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived> + operator<=(const OtherDerived& other) const { + return TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived> + operator>(const OtherDerived& other) const { + return TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived> + operator>=(const OtherDerived& other) const { + return TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived> + operator==(const OtherDerived& other) const { + return TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived> + operator!=(const OtherDerived& other) const { + return TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + } + + // Coefficient-wise ternary operators. + template<typename ThenDerived,typename ElseDerived> + inline const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived> + select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const{ + return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived()); + } + + // Select the device on which to evaluate the expression. template <typename DeviceType> TensorDevice<Derived, DeviceType> device(const DeviceType& device) { return TensorDevice<Derived, DeviceType>(device, derived()); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index 3ce924dc3..e0c0863b7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -68,6 +68,42 @@ struct TensorEvaluator +// -------------------- CwiseNullaryOp -------------------- + +template<typename NullaryOp, typename PlainObjectType> +struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > +{ + typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType; + + enum { + IsAligned = true, + PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess, + }; + + TensorEvaluator(const XprType& op) + : m_functor(op.functor()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const + { + return m_functor(index); + } + + template<int LoadMode> + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const + { + return m_functor.packetOp(index); + } + + private: + const NullaryOp m_functor; +}; + + // -------------------- CwiseUnaryOp -------------------- @@ -146,6 +182,54 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg TensorEvaluator<RightArgType> m_rightImpl; }; + +// -------------------- SelectOp -------------------- + +template<typename IfArgType, typename ThenArgType, typename ElseArgType> +struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> > +{ + typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType; + + enum { + IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned, + PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* & + TensorEvaluator<IfArgType>::PacketAccess*/, + }; + + TensorEvaluator(const XprType& op) + : m_condImpl(op.ifExpression()), + m_thenImpl(op.thenExpression()), + m_elseImpl(op.elseExpression()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const + { + return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index); + } + template<int LoadMode> + EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const + { + static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; + internal::Selector<PacketSize> select; + for (Index i = 0; i < PacketSize; ++i) { + select.select[i] = m_condImpl.coeff(index+i); + } + return internal::pblend(select, + m_thenImpl.template packet<LoadMode>(index), + m_elseImpl.template packet<LoadMode>(index)); + } + + private: + TensorEvaluator<IfArgType> m_condImpl; + TensorEvaluator<ThenArgType> m_thenImpl; + TensorEvaluator<ElseArgType> m_elseImpl; +}; + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index e32077f6e..94cfae05c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -17,6 +17,9 @@ namespace Eigen { * * \brief Tensor expression classes. * + * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This + * is typically used to generate constants. + * * The TensorCwiseUnaryOp class represents an expression where a unary operator * (e.g. cwiseSqrt) is applied to an expression. * @@ -24,6 +27,46 @@ namespace Eigen { * (e.g. addition) is applied to a lhs and a rhs expression. * */ +namespace internal { +template<typename NullaryOp, typename PlainObjectType> +struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > + : traits<PlainObjectType> +{ + typedef typename PlainObjectType::Packet Packet; + typedef typename PlainObjectType::Scalar Scalar; + typedef typename PlainObjectType::Nested XprTypeNested; + typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; +}; + +} // end namespace internal + + + +template<typename NullaryOp, typename PlainObjectType> +class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > +{ + public: + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; + typedef typename PlainObjectType::PacketReturnType PacketReturnType; + typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested; + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp()) + : m_functor(func) {} + + EIGEN_DEVICE_FUNC + const NullaryOp& functor() const { return m_functor; } + + protected: + // todo: add tensor dimension to be able to do some sanity checks + const NullaryOp m_functor; +}; + + namespace internal { template<typename UnaryOp, typename XprType> @@ -160,6 +203,72 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX const BinaryOp m_functor; }; + +namespace internal { +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > + : traits<ThenXprType> +{ + typedef typename traits<ThenXprType>::Scalar Scalar; + typedef typename internal::packet_traits<Scalar>::type Packet; + typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, + typename traits<ElseXprType>::StorageKind>::ret StorageKind; + typedef typename promote_index_type<typename traits<ElseXprType>::Index, + typename traits<ThenXprType>::Index>::type Index; + typedef typename IfXprType::Nested IfNested; + typedef typename ThenXprType::Nested ThenNested; + typedef typename ElseXprType::Nested ElseNested; +}; + +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense> +{ + typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type; +}; + +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type> +{ + typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type; +}; + +} // end namespace internal + + +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > +{ + public: + typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar; + typedef typename Eigen::internal::traits<TensorSelectOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType, + typename ElseXprType::CoeffReturnType>::ret CoeffReturnType; + typedef typename internal::promote_storage_type<typename ThenXprType::PacketReturnType, + typename ElseXprType::PacketReturnType>::ret PacketReturnType; + typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested; + typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index; + + TensorSelectOp(const IfXprType& a_condition, + const ThenXprType& a_then, + const ElseXprType& a_else) + : m_condition(a_condition), m_then(a_then), m_else(a_else) + { } + + const IfXprType& ifExpression() const { return m_condition; } + + const ThenXprType& thenExpression() const { return m_then; } + + const ElseXprType& elseExpression() const { return m_else; } + + protected: + typename IfXprType::Nested m_condition; + typename ThenXprType::Nested m_then; + typename ElseXprType::Nested m_else; +}; + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index 09b0fe66d..03ac8d516 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -17,8 +17,10 @@ template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFi template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap; template<typename Derived> class TensorBase; +template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp; template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp; template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp; +template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp; template<typename ExpressionType, typename DeviceType> class TensorDevice; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h index 3fc9c5335..3a2ff5b30 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h @@ -45,33 +45,37 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor static const int Options = Options_; + static const std::size_t NumIndices = PlainObjectType::NumIndices; + typedef typename PlainObjectType::Dimensions Dimensions; + + enum { IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned), PacketAccess = true, }; EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension}})) { + EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>(firstDimension)) { // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. - EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) + EIGEN_STATIC_ASSERT(1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) } #ifdef EIGEN_HAS_VARIADIC_TEMPLATES template<typename... IndexTypes> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension, otherDimensions...}})) { + EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>({{firstDimension, otherDimensions...}})) { // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. - EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) + EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) } #endif - inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions) + inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions) : m_data(dataPtr), m_dimensions(dimensions) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const typename PlainObjectType::Dimensions& dimensions() const { return m_dimensions; } + EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); } EIGEN_DEVICE_FUNC @@ -80,7 +84,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices) const + EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const { // eigen_assert(checkIndexRange(indices)); if (PlainObjectType::Options&RowMajor) { @@ -96,12 +100,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor template<typename... IndexTypes> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const { - static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); + static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); if (PlainObjectType::Options&RowMajor) { - const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}}); return m_data[index]; } else { - const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}}); return m_data[index]; } } @@ -159,7 +163,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor #endif EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices) + EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices) { // eigen_assert(checkIndexRange(indices)); if (PlainObjectType::Options&RowMajor) { @@ -175,12 +179,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor template<typename... IndexTypes> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) { - static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); + static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); if (PlainObjectType::Options&RowMajor) { - const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}}); return m_data[index]; } else { - const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}}); + const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}}); return m_data[index]; } } @@ -247,8 +251,8 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor } private: - typename PlainObjectType::Scalar* m_data; - typename PlainObjectType::Dimensions m_dimensions; + Scalar* m_data; + Dimensions m_dimensions; }; } // end namespace Eigen |