// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Benoit Steiner // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_CXX11_TENSOR_TENSOR_BASE_H #define EIGEN_CXX11_TENSOR_TENSOR_BASE_H namespace Eigen { /** \class TensorBase * \ingroup CXX11_Tensor_Module * * \brief The tensor base class. * * This class is the common parent of the Tensor and TensorMap class, thus * making it possible to use either class interchangably in expressions. */ template class TensorBase { public: typedef typename internal::traits::Scalar Scalar; typedef typename internal::traits::Index Index; typedef Scalar CoeffReturnType; typedef typename internal::packet_traits::type PacketReturnType; // Dimensions EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return derived().dimensions()[n]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return internal::array_prod(derived().dimensions()); } // Nullary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> constant(const Scalar& value) const { return TensorCwiseNullaryOp, const Derived> (derived(), internal::scalar_constant_op(value)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> random() const { return TensorCwiseNullaryOp, const Derived>(derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp random() const { return TensorCwiseNullaryOp(derived()); } // Coefficient-wise unary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator-() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> sqrt() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> square() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> inverse() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> exp() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> log() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> abs() const { return derived(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> pow(Scalar exponent) const { return TensorCwiseUnaryOp, const Derived> (derived(), internal::scalar_pow_op(exponent)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator * (Scalar scale) const { return TensorCwiseUnaryOp, const Derived> (derived(), internal::scalar_multiple_op(scale)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMax(Scalar threshold) const { return cwiseMax(constant(threshold)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMin(Scalar threshold) const { return cwiseMin(constant(threshold)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp unaryExpr(const CustomUnaryOp& func) const { return TensorCwiseUnaryOp(derived(), func); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> cast() const { return derived(); } // Coefficient-wise binary operators. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator+(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator-(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator*(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator/(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMax(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMin(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } // Comparisons and tests. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<=(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator>(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator>=(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator==(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator!=(const OtherDerived& other) const { return TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } // Contractions. typedef Eigen::IndexPair DimensionPair; template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorContractionOp contract(const OtherDerived& other, const Dimensions& dims) const { return TensorContractionOp(derived(), other.derived(), dims); } // Convolutions. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorConvolutionOp convolve(const KernelDerived& kernel, const Dimensions& dims) const { return TensorConvolutionOp(derived(), kernel.derived(), dims); } // Coefficient-wise ternary operators. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorSelectOp select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const { return TensorSelectOp(derived(), thenTensor.derived(), elseTensor.derived()); } // Reductions. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> sum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::SumReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> maximum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MaxReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> minimum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MinReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp reduce(const Dims& dims, const Reducer& reducer) const { return TensorReductionOp(derived(), dims, reducer); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorBroadcastingOp broadcast(const Broadcast& broadcast) const { return TensorBroadcastingOp(derived(), broadcast); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorConcatenationOp concatenate(const OtherDerived& other, Axis axis) const { return TensorConcatenationOp(derived(), other.derived(), axis); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPatchOp extract_patches(const PatchDims& patch_dims) const { return TensorPatchOp(derived(), patch_dims); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches() const { return TensorImagePatchOp(derived(), Rows, Cols, 1, 1); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches(const Index patch_rows, const Index patch_cols, const Index row_stride = 1, const Index col_stride = 1) const { return TensorImagePatchOp(derived(), patch_rows, patch_cols, row_stride, col_stride); } // Morphing operators. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReshapingOp reshape(const NewDimensions& newDimensions) const { return TensorReshapingOp(derived(), newDimensions); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorSlicingOp slice(const StartIndices& startIndices, const Sizes& sizes) const { return TensorSlicingOp(derived(), startIndices, sizes); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorChippingOp chip(const Index offset) const { return TensorChippingOp(derived(), offset); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPaddingOp pad(const PaddingDimensions& padding) const { return TensorPaddingOp(derived(), padding); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorShufflingOp shuffle(const Shuffle& shuffle) const { return TensorShufflingOp(derived(), shuffle); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorStridingOp stride(const Strides& strides) const { return TensorStridingOp(derived(), strides); } // Force the evaluation of the expression. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorForcedEvalOp eval() const { return TensorForcedEvalOp(derived()); } protected: template friend class Tensor; template friend class TensorBase; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast(this); } }; template class TensorBase : public TensorBase { public: typedef typename internal::traits::Scalar Scalar; typedef typename internal::traits::Index Index; typedef Scalar CoeffReturnType; typedef typename internal::packet_traits::type PacketReturnType; template friend class Tensor; template friend class TensorBase; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setZero() { return setConstant(Scalar(0)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) { return derived() = this->constant(val); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& setRandom() { return derived() = this->random(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator+=(const OtherDerived& other) { return derived() = TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator-=(const OtherDerived& other) { return derived() = TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const OtherDerived& other) { return derived() = TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const OtherDerived& other) { return derived() = TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorReshapingOp reshape(const NewDimensions& newDimensions) const { return TensorReshapingOp(derived(), newDimensions); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorSlicingOp slice(const StartIndices& startIndices, const Sizes& sizes) const { return TensorSlicingOp(derived(), startIndices, sizes); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorChippingOp chip(const Index offset) const { return TensorChippingOp(derived(), offset); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorShufflingOp shuffle(const Shuffle& shuffle) const { return TensorShufflingOp(derived(), shuffle); } // Select the device on which to evaluate the expression. template TensorDevice device(const DeviceType& device) { return TensorDevice(device, derived()); } protected: EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& derived() { return *static_cast(this); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast(this); } }; } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H