diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBase.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 315 |
1 files changed, 244 insertions, 71 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index f451a3c99..8860f622b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -25,77 +25,118 @@ template<typename Derived> class TensorBase<Derived, ReadOnlyAccessors> { public: - typedef typename internal::traits<Derived>::Scalar Scalar; - typedef typename internal::traits<Derived>::Index Index; - typedef Scalar CoeffReturnType; - typedef typename internal::packet_traits<Scalar>::type PacketReturnType; + typedef internal::traits<Derived> DerivedTraits; + typedef typename DerivedTraits::Scalar Scalar; + typedef typename DerivedTraits::Index Index; + typedef typename internal::remove_const<Scalar>::type CoeffReturnType; + typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType; + static const int NumDimensions = DerivedTraits::NumDimensions; - // Dimensions - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return derived().dimensions()[n]; } - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Index size() const { return internal::array_prod(derived().dimensions()); } + // Generic nullary operation support. + template <typename CustomNullaryOp> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<CustomNullaryOp, const Derived> + nullaryExpr(const CustomNullaryOp& func) const { + return TensorCwiseNullaryOp<CustomNullaryOp, const Derived>(derived(), func); + } - // Nullary operators + // Coefficient-wise nullary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> constant(const Scalar& value) const { - return TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> - (derived(), internal::scalar_constant_op<Scalar>(value)); + return nullaryExpr(internal::scalar_constant_op<Scalar>(value)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::UniformRandomGenerator<Scalar>, const Derived> random() const { - return TensorCwiseNullaryOp<internal::UniformRandomGenerator<Scalar>, const Derived>(derived()); + return nullaryExpr(internal::UniformRandomGenerator<Scalar>()); } template <typename RandomGenerator> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<RandomGenerator, const Derived> random() const { - return TensorCwiseNullaryOp<RandomGenerator, const Derived>(derived()); + return nullaryExpr(RandomGenerator()); + } + + // Generic unary operation support. + template <typename CustomUnaryOp> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<CustomUnaryOp, const Derived> + unaryExpr(const CustomUnaryOp& func) const { + return TensorCwiseUnaryOp<CustomUnaryOp, const Derived>(derived(), func); } // Coefficient-wise unary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const Derived> - operator-() const { return derived(); } + operator-() const { + return unaryExpr(internal::scalar_opposite_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> - sqrt() const { return derived(); } + sqrt() const { + return unaryExpr(internal::scalar_sqrt_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived> - square() const { return derived(); } + square() const { + return unaryExpr(internal::scalar_square_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> - inverse() const { return derived(); } + inverse() const { + return unaryExpr(internal::scalar_inverse_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_exp_op<Scalar>, const Derived> - exp() const { return derived(); } + exp() const { + return unaryExpr(internal::scalar_exp_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_log_op<Scalar>, const Derived> - log() const { return derived(); } + log() const { + return unaryExpr(internal::scalar_log_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived> - abs() const { return derived(); } + abs() const { + return unaryExpr(internal::scalar_abs_op<Scalar>()); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> pow(Scalar exponent) const { - return TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> - (derived(), internal::scalar_pow_op<Scalar>(exponent)); + return unaryExpr(internal::scalar_pow_op<Scalar>(exponent)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived> + operator+ (Scalar rhs) const { + return unaryExpr(internal::scalar_add_op<Scalar>(rhs)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sub_op<Scalar>, const Derived> + operator- (Scalar rhs) const { + EIGEN_STATIC_ASSERT((std::numeric_limits<Scalar>::is_signed || internal::is_same<Scalar, const std::complex<float> >::value), YOU_MADE_A_PROGRAMMING_MISTAKE); + return unaryExpr(internal::scalar_sub_op<Scalar>(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived> - operator * (Scalar scale) const { - return TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived> - (derived(), internal::scalar_multiple_op<Scalar>(scale)); + operator* (Scalar rhs) const { + return unaryExpr(internal::scalar_multiple_op<Scalar>(rhs)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_quotient1_op<Scalar>, const Derived> + operator/ (Scalar rhs) const { + // EIGEN_STATIC_ASSERT(!std::numeric_limits<Scalar>::is_integer, YOU_MADE_A_PROGRAMMING_MISTAKE); + return unaryExpr(internal::scalar_quotient1_op<Scalar>(rhs)); } EIGEN_DEVICE_FUNC @@ -110,86 +151,106 @@ class TensorBase<Derived, ReadOnlyAccessors> return cwiseMin(constant(threshold)); } - template <typename CustomUnaryOp> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<CustomUnaryOp, const Derived> - unaryExpr(const CustomUnaryOp& func) const { - return TensorCwiseUnaryOp<CustomUnaryOp, const Derived>(derived(), func); - } - template <typename NewType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_cast_op<Scalar, NewType>, const Derived> cast() const { - return derived(); + return unaryExpr(internal::scalar_cast_op<Scalar, NewType>()); + } + + // Generic binary operation support. + template <typename CustomBinaryOp, typename OtherDerived> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<CustomBinaryOp, const Derived, const OtherDerived> + binaryExpr(const OtherDerived& other, const CustomBinaryOp& func) const { + return TensorCwiseBinaryOp<CustomBinaryOp, const Derived, const OtherDerived>(derived(), other, func); } // Coefficient-wise binary operators. template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived> operator+(const OtherDerived& other) const { - return TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), internal::scalar_sum_op<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived> operator-(const OtherDerived& other) const { - return TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), internal::scalar_difference_op<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived> operator*(const OtherDerived& other) const { - return TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), internal::scalar_product_op<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived> operator/(const OtherDerived& other) const { - return TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), internal::scalar_quotient_op<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived> cwiseMax(const OtherDerived& other) const { - return TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), internal::scalar_max_op<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived> cwiseMin(const OtherDerived& other) const { - return TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), internal::scalar_min_op<Scalar>()); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_boolean_and_op, const Derived, const OtherDerived> + operator&&(const OtherDerived& other) const { + return binaryExpr(other.derived(), internal::scalar_boolean_and_op()); + } + + template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_boolean_or_op, const Derived, const OtherDerived> + operator||(const OtherDerived& other) const { + return binaryExpr(other.derived(), internal::scalar_boolean_or_op()); } // Comparisons and tests. template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived> operator<(const OtherDerived& other) const { - return TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), std::less<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived> operator<=(const OtherDerived& other) const { - return TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), std::less_equal<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived> operator>(const OtherDerived& other) const { - return TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), std::greater<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived> operator>=(const OtherDerived& other) const { - return TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), std::greater_equal<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived> operator==(const OtherDerived& other) const { - return TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), std::equal_to<Scalar>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived> operator!=(const OtherDerived& other) const { - return TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return binaryExpr(other.derived(), std::not_equal_to<Scalar>()); + } + + // Coefficient-wise ternary operators. + template<typename ThenDerived, typename ElseDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived> + select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const { + return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived()); } // Contractions. @@ -208,29 +269,72 @@ class TensorBase<Derived, ReadOnlyAccessors> return TensorConvolutionOp<const Dimensions, const Derived, const KernelDerived>(derived(), kernel.derived(), dims); } - // Coefficient-wise ternary operators. - template<typename ThenDerived, typename ElseDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived> - select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const { - return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived()); - } - // Reductions. template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::SumReducer<Scalar>, const Dims, const Derived> + const TensorReductionOp<internal::SumReducer<CoeffReturnType>, const Dims, const Derived> sum(const Dims& dims) const { - return TensorReductionOp<internal::SumReducer<Scalar>, const Dims, const Derived>(derived(), dims, internal::SumReducer<Scalar>()); + return TensorReductionOp<internal::SumReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::SumReducer<CoeffReturnType>()); + } + + const TensorReductionOp<internal::SumReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + sum() const { + array<Index, NumDimensions> in_dims; + for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; + return TensorReductionOp<internal::SumReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::SumReducer<CoeffReturnType>()); + } + + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const Dims, const Derived> + mean(const Dims& dims) const { + return TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MeanReducer<CoeffReturnType>()); } + + const TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + mean() const { + array<Index, NumDimensions> in_dims; + for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; + return TensorReductionOp<internal::MeanReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MeanReducer<CoeffReturnType>()); + } + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::MaxReducer<Scalar>, const Dims, const Derived> + const TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const Dims, const Derived> + prod(const Dims& dims) const { + return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::ProdReducer<CoeffReturnType>()); + } + + const TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + prod() const { + array<Index, NumDimensions> in_dims; + for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; + return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::ProdReducer<CoeffReturnType>()); + } + + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived> maximum(const Dims& dims) const { - return TensorReductionOp<internal::MaxReducer<Scalar>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<Scalar>()); + return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType>()); + } + + const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + maximum() const { + array<Index, NumDimensions> in_dims; + for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; + return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType>()); } + template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp<internal::MinReducer<Scalar>, const Dims, const Derived> + const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived> minimum(const Dims& dims) const { - return TensorReductionOp<internal::MinReducer<Scalar>, const Dims, const Derived>(derived(), dims, internal::MinReducer<Scalar>()); + return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType>()); } + + const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived> + minimum() const { + array<Index, NumDimensions> in_dims; + for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; + return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const array<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>()); + } + template <typename Reducer, typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp<Reducer, const Dims, const Derived> reduce(const Dims& dims, const Reducer& reducer) const { @@ -258,17 +362,44 @@ class TensorBase<Derived, ReadOnlyAccessors> template <Index Rows, Index Cols> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp<Rows, Cols, const Derived> extract_image_patches() const { - return TensorImagePatchOp<Rows, Cols, const Derived>(derived(), Rows, Cols, 1, 1); + return TensorImagePatchOp<Rows, Cols, const Derived>(derived(), Rows, Cols, 1, 1, PADDING_SAME); + } + + template <Index Rows, Index Cols> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorImagePatchOp<Rows, Cols, const Derived> + extract_image_patches(const PaddingType padding_type) const { + return TensorImagePatchOp<Rows, Cols, const Derived>(derived(), Rows, Cols, 1, 1, padding_type); + } + + template <Index Rows, Index Cols> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorImagePatchOp<Rows, Cols, const Derived> + extract_image_patches(const Index stride, const PaddingType padding_type) const { + return TensorImagePatchOp<Rows, Cols, const Derived>(derived(), Rows, Cols, stride, stride, padding_type); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp<Dynamic, Dynamic, const Derived> extract_image_patches(const Index patch_rows, const Index patch_cols, const Index row_stride = 1, const Index col_stride = 1) const { - return TensorImagePatchOp<Dynamic, Dynamic, const Derived>(derived(), patch_rows, patch_cols, row_stride, col_stride); + return TensorImagePatchOp<Dynamic, Dynamic, const Derived>(derived(), patch_rows, patch_cols, row_stride, col_stride, + PADDING_SAME); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorImagePatchOp<Dynamic, Dynamic, const Derived> + extract_image_patches(const Index patch_rows, const Index patch_cols, + const Index row_stride, const Index col_stride, + const PaddingType padding_type) const { + return TensorImagePatchOp<Dynamic, Dynamic, const Derived>(derived(), patch_rows, patch_cols, row_stride, col_stride, + padding_type); } // Morphing operators. + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorLayoutSwapOp<const Derived> + swap_layout() const { + return TensorLayoutSwapOp<const Derived>(derived()); + } template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReshapingOp<const NewDimensions, const Derived> reshape(const NewDimensions& newDimensions) const { @@ -279,10 +410,20 @@ class TensorBase<Derived, ReadOnlyAccessors> slice(const StartIndices& startIndices, const Sizes& sizes) const { return TensorSlicingOp<const StartIndices, const Sizes, const Derived>(derived(), startIndices, sizes); } - template <std::size_t DimId> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + template <Index DimId> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorChippingOp<DimId, const Derived> chip(const Index offset) const { - return TensorChippingOp<DimId, const Derived>(derived(), offset); + return TensorChippingOp<DimId, const Derived>(derived(), offset, DimId); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorChippingOp<Dynamic, const Derived> + chip(const Index offset, const Index dim) const { + return TensorChippingOp<Dynamic, const Derived>(derived(), offset, dim); + } + template <typename ReverseDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReverseOp<const ReverseDimensions, const Derived> + reverse(const ReverseDimensions& rev) const { + return TensorReverseOp<const ReverseDimensions, const Derived>(derived(), rev); } template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPaddingOp<const PaddingDimensions, const Derived> @@ -308,21 +449,24 @@ class TensorBase<Derived, ReadOnlyAccessors> protected: template <typename Scalar, std::size_t NumIndices, int Options> friend class Tensor; + template <typename Scalar, int Options> friend class TensorVarDim; template <typename OtherDerived, int AccessLevel> friend class TensorBase; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); } }; - template<typename Derived> class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyAccessors> { public: - typedef typename internal::traits<Derived>::Scalar Scalar; - typedef typename internal::traits<Derived>::Index Index; + typedef internal::traits<Derived> DerivedTraits; + typedef typename DerivedTraits::Scalar Scalar; + typedef typename DerivedTraits::Index Index; typedef Scalar CoeffReturnType; typedef typename internal::packet_traits<Scalar>::type PacketReturnType; + static const int NumDimensions = DerivedTraits::NumDimensions; template <typename Scalar, std::size_t NumIndices, int Options> friend class Tensor; + template <typename Scalar, int Options> friend class TensorVarDim; template <typename OtherDerived, int AccessLevel> friend class TensorBase; EIGEN_DEVICE_FUNC @@ -337,24 +481,43 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA EIGEN_STRONG_INLINE Derived& setRandom() { return derived() = this->random(); } + template <typename RandomGenerator> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setRandom() { + return derived() = this->template random<RandomGenerator>(); + } + +#ifdef EIGEN_HAS_VARIADIC_TEMPLATES + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setValues( + const typename internal::Initializer<Derived, NumDimensions>::InitList& vals) { + TensorEvaluator<Derived, DefaultDevice> eval(derived(), DefaultDevice()); + internal::initialize_tensor<Derived, NumDimensions>(eval, vals); + return derived(); + } +#endif // EIGEN_HAS_VARIADIC_TEMPLATES template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator+=(const OtherDerived& other) { - return derived() = TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return derived() = derived() + other.derived(); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator-=(const OtherDerived& other) { - return derived() = TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return derived() = derived() - other.derived(); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const OtherDerived& other) { - return derived() = TensorCwiseBinaryOp<internal::scalar_product_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return derived() = derived() * other.derived(); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const OtherDerived& other) { - return derived() = TensorCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); + return derived() = derived() / other.derived(); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TensorLayoutSwapOp<Derived> + swap_layout() const { + return TensorLayoutSwapOp<Derived>(derived()); + } template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReshapingOp<const NewDimensions, Derived> reshape(const NewDimensions& newDimensions) const { @@ -365,16 +528,26 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA slice(const StartIndices& startIndices, const Sizes& sizes) const { return TensorSlicingOp<const StartIndices, const Sizes, Derived>(derived(), startIndices, sizes); } - template <std::size_t DimId> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + template <DenseIndex DimId> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorChippingOp<DimId, Derived> chip(const Index offset) const { - return TensorChippingOp<DimId, Derived>(derived(), offset); + return TensorChippingOp<DimId, Derived>(derived(), offset, DimId); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TensorChippingOp<Dynamic, Derived> + chip(const Index offset, const Index dim) const { + return TensorChippingOp<Dynamic, Derived>(derived(), offset, dim); } template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorShufflingOp<const Shuffle, Derived> shuffle(const Shuffle& shuffle) const { return TensorShufflingOp<const Shuffle, Derived>(derived(), shuffle); } + template <typename Strides> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TensorStridingOp<const Strides, Derived> + stride(const Strides& strides) const { + return TensorStridingOp<const Strides, Derived>(derived(), strides); + } // Select the device on which to evaluate the expression. template <typename DeviceType> |