aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-22 16:22:35 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-05-22 16:22:35 -0700
commit736267cf6b17832a571acf7e34ca07c7f55907ee (patch)
tree894d0bfd7455b670117a252afad0157ba01a766b
parent7402fea0a8e63e3ea248257047c584afee8f8bde (diff)
Added support for additional tensor operations:
* comparison (<, <=, ==, !=, ...) * selection * nullary ops such as random or constant generation * misc unary ops such as log(), exp(), or a user defined unaryExpr() Cleaned up the code a little.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h139
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h84
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h109
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h2
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMap.h36
5 files changed, 339 insertions, 31 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index fa1bd3498..8a88ba806 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -33,21 +33,25 @@ class TensorBase
Derived& setZero() {
return setConstant(Scalar(0));
}
-
Derived& setConstant(const Scalar& val) {
- Scalar* data = derived().data();
- for (int i = 0; i < derived().size(); ++i) {
- data[i] = val;
- }
- return derived();
+ return derived() = constant(val);
}
-
Derived& setRandom() {
- Scalar* data = derived().data();
- for (int i = 0; i < derived().size(); ++i) {
- data[i] = internal::random_default_impl<Scalar, false, false>::run();
- }
- return derived();
+ return derived() = random();
+ }
+
+ // Nullary operators
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
+ constant(const Scalar& value) const {
+ return TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
+ (internal::scalar_constant_op<Scalar>(value));
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>
+ random() const {
+ return TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>();
}
// Coefficient-wise unary operators
@@ -57,15 +61,31 @@ class TensorBase
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived>
- cwiseSqrt() const { return derived(); }
+ sqrt() const { return derived(); }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived>
+ square() const { return derived(); }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived>
+ inverse() const { return derived(); }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_exp_op<Scalar>, const Derived>
+ exp() const { return derived(); }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived>
+ log() const { return derived(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived>
- cwiseAbs() const { return derived(); }
+ abs() const { return derived(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
- cwisePow(Scalar exponent) const {
+ pow(Scalar exponent) const {
return TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
(derived(), internal::scalar_pow_op<Scalar>(exponent));
}
@@ -77,6 +97,30 @@ class TensorBase
(derived(), internal::scalar_multiple_op<Scalar>(scale));
}
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ cwiseMax(Scalar threshold) const {
+ return cwiseMax(constant(threshold));
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ cwiseMin(Scalar threshold) const {
+ return cwiseMin(constant(threshold));
+ }
+
+ template <typename CustomUnaryOp> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<CustomUnaryOp, const Derived>
+ unaryExpr(const CustomUnaryOp& func) const {
+ return TensorCwiseUnaryOp<CustomUnaryOp, const Derived>(derived(), func);
+ }
+
+ template <typename NewType> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_cast_op<Scalar, NewType>, const Derived>
+ cast() const {
+ return derived();
+ }
+
// Coefficient-wise binary operators.
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>
@@ -90,6 +134,71 @@ class TensorBase
return TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>
+ operator*(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>
+ operator/(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>
+ cwiseMax(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>
+ cwiseMin(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
+ // Comparisons and tests.
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>
+ operator<(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>
+ operator<=(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>
+ operator>(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>
+ operator>=(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
+ operator==(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
+ operator!=(const OtherDerived& other) const {
+ return TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
+ // Coefficient-wise ternary operators.
+ template<typename ThenDerived,typename ElseDerived>
+ inline const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
+ select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const{
+ return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived());
+ }
+
+ // Select the device on which to evaluate the expression.
template <typename DeviceType>
TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
return TensorDevice<Derived, DeviceType>(device, derived());
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 3ce924dc3..e0c0863b7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -68,6 +68,42 @@ struct TensorEvaluator
+// -------------------- CwiseNullaryOp --------------------
+
+template<typename NullaryOp, typename PlainObjectType>
+struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+{
+ typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType;
+
+ enum {
+ IsAligned = true,
+ PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
+ };
+
+ TensorEvaluator(const XprType& op)
+ : m_functor(op.functor())
+ { }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_functor(index);
+ }
+
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
+ {
+ return m_functor.packetOp(index);
+ }
+
+ private:
+ const NullaryOp m_functor;
+};
+
+
// -------------------- CwiseUnaryOp --------------------
@@ -146,6 +182,54 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
TensorEvaluator<RightArgType> m_rightImpl;
};
+
+// -------------------- SelectOp --------------------
+
+template<typename IfArgType, typename ThenArgType, typename ElseArgType>
+struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> >
+{
+ typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
+
+ enum {
+ IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned,
+ PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* &
+ TensorEvaluator<IfArgType>::PacketAccess*/,
+ };
+
+ TensorEvaluator(const XprType& op)
+ : m_condImpl(op.ifExpression()),
+ m_thenImpl(op.thenExpression()),
+ m_elseImpl(op.elseExpression())
+ { }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
+ {
+ return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
+ }
+ template<int LoadMode>
+ EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
+ {
+ static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
+ internal::Selector<PacketSize> select;
+ for (Index i = 0; i < PacketSize; ++i) {
+ select.select[i] = m_condImpl.coeff(index+i);
+ }
+ return internal::pblend(select,
+ m_thenImpl.template packet<LoadMode>(index),
+ m_elseImpl.template packet<LoadMode>(index));
+ }
+
+ private:
+ TensorEvaluator<IfArgType> m_condImpl;
+ TensorEvaluator<ThenArgType> m_thenImpl;
+ TensorEvaluator<ElseArgType> m_elseImpl;
+};
+
+
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index e32077f6e..94cfae05c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -17,6 +17,9 @@ namespace Eigen {
*
* \brief Tensor expression classes.
*
+ * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This
+ * is typically used to generate constants.
+ *
* The TensorCwiseUnaryOp class represents an expression where a unary operator
* (e.g. cwiseSqrt) is applied to an expression.
*
@@ -24,6 +27,46 @@ namespace Eigen {
* (e.g. addition) is applied to a lhs and a rhs expression.
*
*/
+namespace internal {
+template<typename NullaryOp, typename PlainObjectType>
+struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+ : traits<PlainObjectType>
+{
+ typedef typename PlainObjectType::Packet Packet;
+ typedef typename PlainObjectType::Scalar Scalar;
+ typedef typename PlainObjectType::Nested XprTypeNested;
+ typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
+};
+
+} // end namespace internal
+
+
+
+template<typename NullaryOp, typename PlainObjectType>
+class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
+ typedef typename PlainObjectType::PacketReturnType PacketReturnType;
+ typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested;
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp())
+ : m_functor(func) {}
+
+ EIGEN_DEVICE_FUNC
+ const NullaryOp& functor() const { return m_functor; }
+
+ protected:
+ // todo: add tensor dimension to be able to do some sanity checks
+ const NullaryOp m_functor;
+};
+
+
namespace internal {
template<typename UnaryOp, typename XprType>
@@ -160,6 +203,72 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
const BinaryOp m_functor;
};
+
+namespace internal {
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
+ : traits<ThenXprType>
+{
+ typedef typename traits<ThenXprType>::Scalar Scalar;
+ typedef typename internal::packet_traits<Scalar>::type Packet;
+ typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
+ typename traits<ElseXprType>::StorageKind>::ret StorageKind;
+ typedef typename promote_index_type<typename traits<ElseXprType>::Index,
+ typename traits<ThenXprType>::Index>::type Index;
+ typedef typename IfXprType::Nested IfNested;
+ typedef typename ThenXprType::Nested ThenNested;
+ typedef typename ElseXprType::Nested ElseNested;
+};
+
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
+{
+ typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
+};
+
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
+{
+ typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
+};
+
+} // end namespace internal
+
+
+template<typename IfXprType, typename ThenXprType, typename ElseXprType>
+class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
+{
+ public:
+ typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
+ typedef typename Eigen::internal::traits<TensorSelectOp>::Packet Packet;
+ typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+ typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
+ typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
+ typedef typename internal::promote_storage_type<typename ThenXprType::PacketReturnType,
+ typename ElseXprType::PacketReturnType>::ret PacketReturnType;
+ typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
+ typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
+ typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
+
+ TensorSelectOp(const IfXprType& a_condition,
+ const ThenXprType& a_then,
+ const ElseXprType& a_else)
+ : m_condition(a_condition), m_then(a_then), m_else(a_else)
+ { }
+
+ const IfXprType& ifExpression() const { return m_condition; }
+
+ const ThenXprType& thenExpression() const { return m_then; }
+
+ const ElseXprType& elseExpression() const { return m_else; }
+
+ protected:
+ typename IfXprType::Nested m_condition;
+ typename ThenXprType::Nested m_then;
+ typename ElseXprType::Nested m_else;
+};
+
+
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
index 09b0fe66d..03ac8d516 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -17,8 +17,10 @@ template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFi
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
template<typename Derived> class TensorBase;
+template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
+template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template<typename ExpressionType, typename DeviceType> class TensorDevice;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
index 3fc9c5335..3a2ff5b30 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
@@ -45,33 +45,37 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
static const int Options = Options_;
+ static const std::size_t NumIndices = PlainObjectType::NumIndices;
+ typedef typename PlainObjectType::Dimensions Dimensions;
+
+
enum {
IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned),
PacketAccess = true,
};
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension}})) {
+ EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>(firstDimension)) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
- EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ EIGEN_STATIC_ASSERT(1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension, otherDimensions...}})) {
+ EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(array<DenseIndex, NumIndices>({{firstDimension, otherDimensions...}})) {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
- EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
}
#endif
- inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions)
+ inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
: m_data(dataPtr), m_dimensions(dimensions)
{ }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const typename PlainObjectType::Dimensions& dimensions() const { return m_dimensions; }
+ EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
EIGEN_DEVICE_FUNC
@@ -80,7 +84,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; }
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices) const
+ EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
{
// eigen_assert(checkIndexRange(indices));
if (PlainObjectType::Options&RowMajor) {
@@ -96,12 +100,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const
{
- static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
+ static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
if (PlainObjectType::Options&RowMajor) {
- const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
+ const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index];
} else {
- const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
+ const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index];
}
}
@@ -159,7 +163,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
#endif
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices)
+ EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
{
// eigen_assert(checkIndexRange(indices));
if (PlainObjectType::Options&RowMajor) {
@@ -175,12 +179,12 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices)
{
- static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
+ static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
if (PlainObjectType::Options&RowMajor) {
- const Index index = m_dimensions.IndexOfRowMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
+ const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index];
} else {
- const Index index = m_dimensions.IndexOfColMajor(array<Index, PlainObjectType::NumIndices>{{firstIndex, otherIndices...}});
+ const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
return m_data[index];
}
}
@@ -247,8 +251,8 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
}
private:
- typename PlainObjectType::Scalar* m_data;
- typename PlainObjectType::Dimensions m_dimensions;
+ Scalar* m_data;
+ Dimensions m_dimensions;
};
} // end namespace Eigen