// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Benoit Steiner // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_CXX11_TENSOR_TENSOR_BASE_H #define EIGEN_CXX11_TENSOR_TENSOR_BASE_H namespace Eigen { /** \class TensorBase * \ingroup CXX11_Tensor_Module * * \brief The tensor base class. * * This class is the common parent of the Tensor and TensorMap class, thus * making it possible to use either class interchangably in expressions. */ template class TensorBase { public: typedef internal::traits DerivedTraits; typedef typename DerivedTraits::Scalar Scalar; typedef typename DerivedTraits::Index Index; typedef typename internal::remove_const::type CoeffReturnType; typedef typename internal::packet_traits::type PacketReturnType; static const int NumDimensions = DerivedTraits::NumDimensions; // Generic nullary operation support. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp nullaryExpr(const CustomNullaryOp& func) const { return TensorCwiseNullaryOp(derived(), func); } // Coefficient-wise nullary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> constant(const Scalar& value) const { return nullaryExpr(internal::scalar_constant_op(value)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> random() const { return nullaryExpr(internal::UniformRandomGenerator()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp random() const { return nullaryExpr(RandomGenerator()); } // Generic unary operation support. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp unaryExpr(const CustomUnaryOp& func) const { return TensorCwiseUnaryOp(derived(), func); } // Coefficient-wise unary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator-() const { return unaryExpr(internal::scalar_opposite_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> sqrt() const { return unaryExpr(internal::scalar_sqrt_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> square() const { return unaryExpr(internal::scalar_square_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> cube() const { return unaryExpr(internal::scalar_cube_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> inverse() const { return unaryExpr(internal::scalar_inverse_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> exp() const { return unaryExpr(internal::scalar_exp_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> log() const { return unaryExpr(internal::scalar_log_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> abs() const { return unaryExpr(internal::scalar_abs_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> pow(Scalar exponent) const { return unaryExpr(internal::scalar_pow_op(exponent)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator+ (Scalar rhs) const { return unaryExpr(internal::scalar_add_op(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator- (Scalar rhs) const { EIGEN_STATIC_ASSERT((std::numeric_limits::is_signed || internal::is_same >::value), YOU_MADE_A_PROGRAMMING_MISTAKE); return unaryExpr(internal::scalar_sub_op(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator* (Scalar rhs) const { return unaryExpr(internal::scalar_multiple_op(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator/ (Scalar rhs) const { // EIGEN_STATIC_ASSERT(!std::numeric_limits::is_integer, YOU_MADE_A_PROGRAMMING_MISTAKE); return unaryExpr(internal::scalar_quotient1_op(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMax(Scalar threshold) const { return cwiseMax(constant(threshold)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMin(Scalar threshold) const { return cwiseMin(constant(threshold)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> cast() const { return unaryExpr(internal::scalar_cast_op()); } // Generic binary operation support. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp binaryExpr(const OtherDerived& other, const CustomBinaryOp& func) const { return TensorCwiseBinaryOp(derived(), other, func); } // Coefficient-wise binary operators. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator+(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_sum_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator-(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_difference_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator*(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_product_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator/(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_quotient_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMax(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_max_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMin(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_min_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp operator&&(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_boolean_and_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp operator||(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_boolean_or_op()); } // Comparisons and tests. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<(const OtherDerived& other) const { return binaryExpr(other.derived(), std::less()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<=(const OtherDerived& other) const { return binaryExpr(other.derived(), std::less_equal()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator>(const OtherDerived& other) const { return binaryExpr(other.derived(), std::greater()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator>=(const OtherDerived& other) const { return binaryExpr(other.derived(), std::greater_equal()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator==(const OtherDerived& other) const { return binaryExpr(other.derived(), std::equal_to()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator!=(const OtherDerived& other) const { return binaryExpr(other.derived(), std::not_equal_to()); } // Coefficient-wise ternary operators. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorSelectOp select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const { return TensorSelectOp(derived(), thenTensor.derived(), elseTensor.derived()); } // Contractions. typedef Eigen::IndexPair DimensionPair; template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorContractionOp contract(const OtherDerived& other, const Dimensions& dims) const { return TensorContractionOp(derived(), other.derived(), dims); } // Convolutions. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorConvolutionOp convolve(const KernelDerived& kernel, const Dimensions& dims) const { return TensorConvolutionOp(derived(), kernel.derived(), dims); } // Reductions. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> sum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::SumReducer()); } const TensorReductionOp, const array, const Derived> sum() const { array in_dims; for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::SumReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> mean(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MeanReducer()); } const TensorReductionOp, const array, const Derived> mean() const { array in_dims; for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::MeanReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> prod(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::ProdReducer()); } const TensorReductionOp, const array, const Derived> prod() const { array in_dims; for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::ProdReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> maximum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MaxReducer()); } const TensorReductionOp, const array, const Derived> maximum() const { array in_dims; for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::MaxReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> minimum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MinReducer()); } const TensorReductionOp, const array, const Derived> minimum() const { array in_dims; for (int i = 0; i < NumDimensions; ++i) in_dims[i] = i; return TensorReductionOp, const array, const Derived>(derived(), in_dims, internal::MinReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp reduce(const Dims& dims, const Reducer& reducer) const { return TensorReductionOp(derived(), dims, reducer); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorBroadcastingOp broadcast(const Broadcast& broadcast) const { return TensorBroadcastingOp(derived(), broadcast); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorConcatenationOp concatenate(const OtherDerived& other, Axis axis) const { return TensorConcatenationOp(derived(), other.derived(), axis); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPatchOp extract_patches(const PatchDims& patch_dims) const { return TensorPatchOp(derived(), patch_dims); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches() const { return TensorImagePatchOp(derived(), Rows, Cols, 1, 1, PADDING_SAME); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches(const PaddingType padding_type) const { return TensorImagePatchOp(derived(), Rows, Cols, 1, 1, padding_type); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches(const Index stride, const PaddingType padding_type) const { return TensorImagePatchOp(derived(), Rows, Cols, stride, stride, padding_type); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches(const Index patch_rows, const Index patch_cols, const Index row_stride = 1, const Index col_stride = 1) const { return TensorImagePatchOp(derived(), patch_rows, patch_cols, row_stride, col_stride, PADDING_SAME); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches(const Index patch_rows, const Index patch_cols, const Index row_stride, const Index col_stride, const PaddingType padding_type) const { return TensorImagePatchOp(derived(), patch_rows, patch_cols, row_stride, col_stride, padding_type); } // Morphing operators. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorLayoutSwapOp swap_layout() const { return TensorLayoutSwapOp(derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReshapingOp reshape(const NewDimensions& newDimensions) const { return TensorReshapingOp(derived(), newDimensions); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorSlicingOp slice(const StartIndices& startIndices, const Sizes& sizes) const { return TensorSlicingOp(derived(), startIndices, sizes); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorChippingOp chip(const Index offset) const { return TensorChippingOp(derived(), offset, DimId); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorChippingOp chip(const Index offset, const Index dim) const { return TensorChippingOp(derived(), offset, dim); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReverseOp reverse(const ReverseDimensions& rev) const { return TensorReverseOp(derived(), rev); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPaddingOp pad(const PaddingDimensions& padding) const { return TensorPaddingOp(derived(), padding); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorShufflingOp shuffle(const Shuffle& shuffle) const { return TensorShufflingOp(derived(), shuffle); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorStridingOp stride(const Strides& strides) const { return TensorStridingOp(derived(), strides); } // Force the evaluation of the expression. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorForcedEvalOp eval() const { return TensorForcedEvalOp(derived()); } protected: template friend class Tensor; template friend class TensorVarDim; template friend class TensorBase; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast(this); } }; template class TensorBase : public TensorBase { public: typedef internal::traits DerivedTraits; typedef typename DerivedTraits::Scalar Scalar; typedef typename DerivedTraits::Index Index; typedef Scalar CoeffReturnType; typedef typename internal::packet_traits::type PacketReturnType; static const int NumDimensions = DerivedTraits::NumDimensions; template friend class Tensor; template friend class TensorVarDim; template friend class TensorBase; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setZero() { return setConstant(Scalar(0)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) { return derived() = this->constant(val); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setRandom() { return derived() = this->random(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setRandom() { return derived() = this->template random(); } #ifdef EIGEN_HAS_VARIADIC_TEMPLATES EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setValues( const typename internal::Initializer::InitList& vals) { TensorEvaluator eval(derived(), DefaultDevice()); internal::initialize_tensor(eval, vals); return derived(); } #endif // EIGEN_HAS_VARIADIC_TEMPLATES template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator+=(const OtherDerived& other) { return derived() = derived() + other.derived(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator-=(const OtherDerived& other) { return derived() = derived() - other.derived(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const OtherDerived& other) { return derived() = derived() * other.derived(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const OtherDerived& other) { return derived() = derived() / other.derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorLayoutSwapOp swap_layout() const { return TensorLayoutSwapOp(derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp concatenate(const OtherDerived& other, const Axis& axis) const { return TensorConcatenationOp(derived(), other.derived(), axis); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReshapingOp reshape(const NewDimensions& newDimensions) const { return TensorReshapingOp(derived(), newDimensions); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorSlicingOp slice(const StartIndices& startIndices, const Sizes& sizes) const { return TensorSlicingOp(derived(), startIndices, sizes); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorChippingOp chip(const Index offset) const { return TensorChippingOp(derived(), offset, DimId); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorChippingOp chip(const Index offset, const Index dim) const { return TensorChippingOp(derived(), offset, dim); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorShufflingOp shuffle(const Shuffle& shuffle) const { return TensorShufflingOp(derived(), shuffle); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorStridingOp stride(const Strides& strides) const { return TensorStridingOp(derived(), strides); } // Select the device on which to evaluate the expression. template TensorDevice device(const DeviceType& device) { return TensorDevice(device, derived()); } protected: EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& derived() { return *static_cast(this); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast(this); } }; } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H