// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Benoit Steiner // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_CXX11_TENSOR_TENSOR_BASE_H #define EIGEN_CXX11_TENSOR_TENSOR_BASE_H // clang-format off namespace Eigen { /** \class TensorBase * \ingroup CXX11_Tensor_Module * * \brief The tensor base class. * * This class is the common parent of the Tensor and TensorMap class, thus * making it possible to use either class interchangeably in expressions. */ #ifndef EIGEN_PARSED_BY_DOXYGEN // FIXME Doxygen does not like the inheritance with different template parameters // Since there is no doxygen documentation inside, we disable it for now template class TensorBase { public: typedef internal::traits DerivedTraits; typedef typename DerivedTraits::Scalar Scalar; typedef typename DerivedTraits::Index Index; typedef typename internal::remove_const::type CoeffReturnType; static const int NumDimensions = DerivedTraits::NumDimensions; // Generic nullary operation support. 
template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp nullaryExpr(const CustomNullaryOp& func) const { return TensorCwiseNullaryOp(derived(), func); } // Coefficient-wise nullary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> constant(const Scalar& value) const { return nullaryExpr(internal::scalar_constant_op(value)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> random() const { return nullaryExpr(internal::UniformRandomGenerator()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp random(const RandomGenerator& gen = RandomGenerator()) const { return nullaryExpr(gen); } // Tensor generation template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorGeneratorOp generate(const Generator& generator) const { return TensorGeneratorOp(derived(), generator); } // Generic unary operation support. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp unaryExpr(const CustomUnaryOp& func) const { return TensorCwiseUnaryOp(derived(), func); } // Coefficient-wise unary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator-() const { return unaryExpr(internal::scalar_opposite_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> sqrt() const { return unaryExpr(internal::scalar_sqrt_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> sign() const { return unaryExpr(internal::scalar_sign_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> rsqrt() const { return unaryExpr(internal::scalar_rsqrt_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> square() const { return unaryExpr(internal::scalar_square_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> cube() const { return unaryExpr(internal::scalar_cube_op()); } 
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> inverse() const { return unaryExpr(internal::scalar_inverse_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> tanh() const { return unaryExpr(internal::scalar_tanh_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> lgamma() const { return unaryExpr(internal::scalar_lgamma_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> digamma() const { return unaryExpr(internal::scalar_digamma_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_i0() const { return unaryExpr(internal::scalar_bessel_i0_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_i0e() const { return unaryExpr(internal::scalar_bessel_i0e_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_i1() const { return unaryExpr(internal::scalar_bessel_i1_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_i1e() const { return unaryExpr(internal::scalar_bessel_i1e_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_j0() const { return unaryExpr(internal::scalar_bessel_j0_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_y0() const { return unaryExpr(internal::scalar_bessel_y0_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_j1() const { return unaryExpr(internal::scalar_bessel_j1_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_y1() const { return unaryExpr(internal::scalar_bessel_y1_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_k0() const { return unaryExpr(internal::scalar_bessel_k0_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const 
TensorCwiseUnaryOp, const Derived> bessel_k0e() const { return unaryExpr(internal::scalar_bessel_k0e_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_k1() const { return unaryExpr(internal::scalar_bessel_k1_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> bessel_k1e() const { return unaryExpr(internal::scalar_bessel_k1e_op()); } // igamma(a = this, x = other) template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> igamma(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_igamma_op()); } // igamma_der_a(a = this, x = other) template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> igamma_der_a(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_igamma_der_a_op()); } // gamma_sample_der_alpha(alpha = this, sample = other) template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> gamma_sample_der_alpha(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_gamma_sample_der_alpha_op()); } // igammac(a = this, x = other) template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> igammac(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_igammac_op()); } // zeta(x = this, q = other) template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> zeta(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_zeta_op()); } // polygamma(n = this, x = other) template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> polygamma(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_polygamma_op()); } 
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> erf() const { return unaryExpr(internal::scalar_erf_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> erfc() const { return unaryExpr(internal::scalar_erfc_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> ndtri() const { return unaryExpr(internal::scalar_ndtri_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> sigmoid() const { return unaryExpr(internal::scalar_logistic_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> exp() const { return unaryExpr(internal::scalar_exp_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> expm1() const { return unaryExpr(internal::scalar_expm1_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> log() const { return unaryExpr(internal::scalar_log_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> log1p() const { return unaryExpr(internal::scalar_log1p_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> log2() const { return unaryExpr(internal::scalar_log2_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> abs() const { return unaryExpr(internal::scalar_abs_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> clip(Scalar min, Scalar max) const { return unaryExpr(internal::scalar_clamp_op(min, max)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename internal::conditional::IsComplex, TensorCwiseUnaryOp, const Derived>, Derived>::type conjugate() const { return choose(Cond::IsComplex>(), unaryExpr(internal::scalar_conjugate_op()), derived()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp >, const Derived> pow(Scalar exponent) const { return unaryExpr(internal::bind2nd_op 
>(exponent)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> real() const { return unaryExpr(internal::scalar_real_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> imag() const { return unaryExpr(internal::scalar_imag_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp >, const Derived> operator+ (Scalar rhs) const { return unaryExpr(internal::bind2nd_op >(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend const TensorCwiseUnaryOp >, const Derived> operator+ (Scalar lhs, const Derived& rhs) { return rhs.unaryExpr(internal::bind1st_op >(lhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp >, const Derived> operator- (Scalar rhs) const { EIGEN_STATIC_ASSERT((NumTraits::IsSigned || internal::is_same >::value), YOU_MADE_A_PROGRAMMING_MISTAKE); return unaryExpr(internal::bind2nd_op >(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend const TensorCwiseUnaryOp >, const Derived> operator- (Scalar lhs, const Derived& rhs) { return rhs.unaryExpr(internal::bind1st_op >(lhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp >, const Derived> operator* (Scalar rhs) const { return unaryExpr(internal::bind2nd_op >(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend const TensorCwiseUnaryOp >, const Derived> operator* (Scalar lhs, const Derived& rhs) { return rhs.unaryExpr(internal::bind1st_op >(lhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp >, const Derived> operator/ (Scalar rhs) const { return unaryExpr(internal::bind2nd_op >(rhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend const TensorCwiseUnaryOp >, const Derived> operator/ (Scalar lhs, const Derived& rhs) { return rhs.unaryExpr(internal::bind1st_op >(lhs)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> operator% (Scalar rhs) const { EIGEN_STATIC_ASSERT(NumTraits::IsInteger, YOU_MADE_A_PROGRAMMING_MISTAKE_TRY_MOD); return 
unaryExpr(internal::scalar_mod_op(rhs)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMax(Scalar threshold) const { return cwiseMax(constant(threshold)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > cwiseMin(Scalar threshold) const { return cwiseMin(constant(threshold)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename internal::conditional::value, Derived, TensorConversionOp >::type cast() const { return choose(Cond::value>(), derived(), TensorConversionOp(derived())); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> round() const { return unaryExpr(internal::scalar_round_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> rint() const { return unaryExpr(internal::scalar_rint_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> ceil() const { return unaryExpr(internal::scalar_ceil_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> floor() const { return unaryExpr(internal::scalar_floor_op()); } // Generic binary operation support. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp binaryExpr(const OtherDerived& other, const CustomBinaryOp& func) const { return TensorCwiseBinaryOp(derived(), other, func); } // Coefficient-wise binary operators. 
template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator+(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_sum_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator-(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_difference_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator*(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_product_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator/(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_quotient_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMax(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_max_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> cwiseMin(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_min_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp operator&&(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_boolean_and_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp operator||(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_boolean_or_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp operator^(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_boolean_xor_op()); } // Comparisons and tests. 
template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<=(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator>(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator>=(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator==(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator!=(const OtherDerived& other) const { return binaryExpr(other.derived(), internal::scalar_cmp_op()); } // comparisons and tests for Scalars EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > operator<(Scalar threshold) const { return operator<(constant(threshold)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > operator<=(Scalar threshold) const { return operator<=(constant(threshold)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > operator>(Scalar threshold) const { return operator>(constant(threshold)); } EIGEN_DEVICE_FUNC 
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > operator>=(Scalar threshold) const { return operator>=(constant(threshold)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > operator==(Scalar threshold) const { return operator==(constant(threshold)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const TensorCwiseNullaryOp, const Derived> > operator!=(Scalar threshold) const { return operator!=(constant(threshold)); } // Checks EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> (isnan)() const { return unaryExpr(internal::scalar_isnan_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> (isinf)() const { return unaryExpr(internal::scalar_isinf_op()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> (isfinite)() const { return unaryExpr(internal::scalar_isfinite_op()); } // Coefficient-wise ternary operators. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorSelectOp select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const { return TensorSelectOp(derived(), thenTensor.derived(), elseTensor.derived()); } // Contractions. typedef Eigen::IndexPair DimensionPair; template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorContractionOp contract(const OtherDerived& other, const Dimensions& dims) const { return TensorContractionOp(derived(), other.derived(), dims); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorContractionOp contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const { return TensorContractionOp(derived(), other.derived(), dims, output_kernel); } // Convolutions. 
template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorConvolutionOp convolve(const KernelDerived& kernel, const Dimensions& dims) const { return TensorConvolutionOp(derived(), kernel.derived(), dims); } // Fourier transforms template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorFFTOp fft(const FFT& dims) const { return TensorFFTOp(derived(), dims); } // Scan. typedef TensorScanOp, const Derived> TensorScanSumOp; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorScanSumOp cumsum(const Index& axis, bool exclusive = false) const { return TensorScanSumOp(derived(), axis, exclusive); } typedef TensorScanOp, const Derived> TensorScanProdOp; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorScanProdOp cumprod(const Index& axis, bool exclusive = false) const { return TensorScanProdOp(derived(), axis, exclusive); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorScanOp scan(const Index& axis, const Reducer& reducer, bool exclusive = false) const { return TensorScanOp(derived(), axis, exclusive, reducer); } // Reductions. 
template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> sum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::SumReducer()); } const TensorReductionOp, const DimensionList, const Derived> sum() const { DimensionList in_dims; return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::SumReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> mean(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MeanReducer()); } const TensorReductionOp, const DimensionList, const Derived> mean() const { DimensionList in_dims; return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MeanReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> prod(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::ProdReducer()); } const TensorReductionOp, const DimensionList, const Derived> prod() const { DimensionList in_dims; return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::ProdReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> maximum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MaxReducer()); } template const TensorReductionOp, const DimensionList, const Derived> maximum() const { DimensionList in_dims; return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MaxReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const Dims, const Derived> minimum(const Dims& dims) const { return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MinReducer()); } template 
const TensorReductionOp, const DimensionList, const Derived> minimum() const { DimensionList in_dims; return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MinReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp::value, Derived, TensorConversionOp >::type > all(const Dims& dims) const { return cast().reduce(dims, internal::AndReducer()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const typename internal::conditional::value, Derived, TensorConversionOp >::type > all() const { DimensionList in_dims; return cast().reduce(in_dims, internal::AndReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp::value, Derived, TensorConversionOp >::type > any(const Dims& dims) const { return cast().reduce(dims, internal::OrReducer()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp, const typename internal::conditional::value, Derived, TensorConversionOp >::type > any() const { DimensionList in_dims; return cast().reduce(in_dims, internal::OrReducer()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTupleReducerOp< internal::ArgMaxTupleReducer >, const array, const Derived> argmax() const { array in_dims; for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d; return TensorTupleReducerOp< internal::ArgMaxTupleReducer >, const array, const Derived>(derived(), internal::ArgMaxTupleReducer >(), -1, in_dims); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTupleReducerOp< internal::ArgMinTupleReducer >, const array, const Derived> argmin() const { array in_dims; for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d; return TensorTupleReducerOp< internal::ArgMinTupleReducer >, const array, const Derived>(derived(), internal::ArgMinTupleReducer >(), -1, in_dims); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTupleReducerOp< internal::ArgMaxTupleReducer >, const array, const Derived> argmax(const Index return_dim) const { array 
in_dims; in_dims[0] = return_dim; return TensorTupleReducerOp< internal::ArgMaxTupleReducer >, const array, const Derived>(derived(), internal::ArgMaxTupleReducer >(), return_dim, in_dims); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTupleReducerOp< internal::ArgMinTupleReducer >, const array, const Derived> argmin(const Index return_dim) const { array in_dims; in_dims[0] = return_dim; return TensorTupleReducerOp< internal::ArgMinTupleReducer >, const array, const Derived>(derived(), internal::ArgMinTupleReducer >(), return_dim, in_dims); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReductionOp reduce(const Dims& dims, const Reducer& reducer) const { return TensorReductionOp(derived(), dims, reducer); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorTraceOp trace(const Dims& dims) const { return TensorTraceOp(derived(), dims); } const TensorTraceOp, const Derived> trace() const { DimensionList in_dims; return TensorTraceOp, const Derived>(derived(), in_dims); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorBroadcastingOp broadcast(const Broadcast& bcast) const { return TensorBroadcastingOp(derived(), bcast); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorConcatenationOp concatenate(const OtherDerived& other, Axis axis) const { return TensorConcatenationOp(derived(), other.derived(), axis); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPatchOp extract_patches(const PatchDims& patch_dims) const { return TensorPatchOp(derived(), patch_dims); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches(const Index patch_rows = 1, const Index patch_cols = 1, const Index row_stride = 1, const Index col_stride = 1, const Index in_row_stride = 1, const Index in_col_stride = 1, const PaddingType padding_type = PADDING_SAME, const Scalar padding_value = Scalar(0)) const { return TensorImagePatchOp(derived(), patch_rows, patch_cols, row_stride, col_stride, in_row_stride, 
in_col_stride, 1, 1, padding_type, padding_value); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorImagePatchOp extract_image_patches(const Index patch_rows, const Index patch_cols, const Index row_stride, const Index col_stride, const Index in_row_stride, const Index in_col_stride, const Index row_inflate_stride, const Index col_inflate_stride, const Index padding_top, const Index padding_bottom, const Index padding_left,const Index padding_right, const Scalar padding_value) const { return TensorImagePatchOp(derived(), patch_rows, patch_cols, row_stride, col_stride, in_row_stride, in_col_stride, row_inflate_stride, col_inflate_stride, padding_top, padding_bottom, padding_left, padding_right, padding_value); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorVolumePatchOp extract_volume_patches(const Index patch_planes, const Index patch_rows, const Index patch_cols, const Index plane_stride = 1, const Index row_stride = 1, const Index col_stride = 1, const PaddingType padding_type = PADDING_SAME, const Scalar padding_value = Scalar(0)) const { return TensorVolumePatchOp(derived(), patch_planes, patch_rows, patch_cols, plane_stride, row_stride, col_stride, 1, 1, 1, 1, 1, 1, padding_type, padding_value); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorVolumePatchOp extract_volume_patches(const Index patch_planes, const Index patch_rows, const Index patch_cols, const Index plane_stride, const Index row_stride, const Index col_stride, const Index plane_inflate_stride, const Index row_inflate_stride, const Index col_inflate_stride, const Index padding_top_z, const Index padding_bottom_z, const Index padding_top, const Index padding_bottom, const Index padding_left, const Index padding_right, const Scalar padding_value = Scalar(0)) const { return TensorVolumePatchOp(derived(), patch_planes, patch_rows, patch_cols, plane_stride, row_stride, col_stride, 1, 1, 1, plane_inflate_stride, row_inflate_stride, col_inflate_stride, padding_top_z, padding_bottom_z, 
padding_top, padding_bottom, padding_left, padding_right, padding_value); } // Morphing operators. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorLayoutSwapOp swap_layout() const { return TensorLayoutSwapOp(derived()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReshapingOp reshape(const NewDimensions& newDimensions) const { return TensorReshapingOp(derived(), newDimensions); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorSlicingOp slice(const StartIndices& startIndices, const Sizes& sizes) const { return TensorSlicingOp(derived(), startIndices, sizes); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorStridingSlicingOp stridedSlice(const StartIndices& startIndices, const StopIndices& stopIndices, const Strides& strides) const { return TensorStridingSlicingOp(derived(), startIndices, stopIndices, strides); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorChippingOp chip(const Index offset) const { return TensorChippingOp(derived(), offset, DimId); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorChippingOp chip(const Index offset, const Index dim) const { return TensorChippingOp(derived(), offset, dim); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReverseOp reverse(const ReverseDimensions& rev) const { return TensorReverseOp(derived(), rev); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPaddingOp pad(const PaddingDimensions& padding) const { return TensorPaddingOp(derived(), padding, internal::scalar_cast_op()(0)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorPaddingOp pad(const PaddingDimensions& padding, const Scalar padding_value) const { return TensorPaddingOp(derived(), padding, padding_value); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorShufflingOp shuffle(const Shuffle& shfl) const { return TensorShufflingOp(derived(), shfl); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorStridingOp stride(const Strides& strides) const { 
return TensorStridingOp(derived(), strides); }

    // Returns a TensorInflationOp expression built from derived() and `strides`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorInflationOp
    inflate(const Strides& strides) const {
      return TensorInflationOp(derived(), strides);
    }

    // Returns a tensor containing index/value tuples
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorIndexTupleOp
    index_tuples() const {
      return TensorIndexTupleOp(derived());
    }

    // Support for custom unary and binary operations
    // Unary form: wraps derived() and the user-supplied functor `op`.
    template
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorCustomUnaryOp customOp(const CustomUnaryFunc& op) const {
      return TensorCustomUnaryOp(derived(), op);
    }
    // Binary form: wraps derived(), the second operand `other`, and the
    // user-supplied functor `op`.
    template
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorCustomBinaryOp customOp(const OtherDerived& other, const CustomBinaryFunc& op) const {
      return TensorCustomBinaryOp(derived(), other, op);
    }

    // Force the evaluation of the expression.
    // Returns a TensorForcedEvalOp wrapping derived().
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorForcedEvalOp eval() const {
      return TensorForcedEvalOp(derived());
    }

  protected:
    template friend class Tensor;
    template friend class TensorFixedSize;
    // the Eigen:: prefix is required to workaround a compilation issue with nvcc 9.0
    template friend class Eigen::TensorBase;
    // Returns this object as a `const Derived&` via static_cast.
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast(this); }
};

// NOTE(review): in this copy of the file every template parameter/argument
// list appears to have been stripped (e.g. `template class`, `template::value>`,
// `static_cast(this)`, `TensorStridingOp(derived(), strides)`), so the code
// cannot compile as written. Restore the lists from upstream Eigen before
// building; no attempt is made here to reconstruct them.
//
// Writable TensorBase. It derives from another TensorBase specialization
// (presumably the read-only one above; the base's template arguments are not
// visible here) and adds:
//  * in-place setters: setZero, setConstant, setRandom, setValues;
//  * compound assignment: +=, -=, *=, /=;
//  * non-const overloads, next to const ones, of swap_layout, concatenate,
//    reshape, slice, stridedSlice, chip, reverse, shuffle and stride;
//  * device selection, and a protected expression assignment operator=.
template::value>
class TensorBase : public TensorBase {
 public:
    typedef TensorBase Base;
    typedef internal::traits DerivedTraits;
    typedef typename DerivedTraits::Scalar Scalar;
    typedef typename DerivedTraits::Index Index;
    // Unlike the read-only base, coefficients are returned as plain Scalar.
    typedef Scalar CoeffReturnType;
    static const int NumDimensions = DerivedTraits::NumDimensions;

    template friend class Tensor;
    template friend class TensorFixedSize;
    // the Eigen:: prefix is required to workaround a compilation issue with nvcc 9.0
    template friend class Eigen::TensorBase;

    // Sets every coefficient to zero by delegating to setConstant(Scalar(0)).
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& setZero() {
      return setConstant(Scalar(0));
    }
    // Assigns the constant expression this->constant(val) to derived() and
    // returns derived().
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
      return derived() = this->constant(val);
    }
    // Assigns this->random() (the base's default-generator random expression)
    // to derived() and returns derived().
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& setRandom() {
      return derived() = this->random();
    }
    // Same as setRandom(), but with a caller-selected generator type.
    // NOTE(review): the explicit template argument passed to random() is
    // missing in this copy.
    template EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& setRandom() {
      return derived() = this->template random();
    }

#if EIGEN_HAS_VARIADIC_TEMPLATES
    // Writes the coefficients listed in `vals` into derived() through a
    // TensorEvaluator constructed on DefaultDevice(), then returns derived().
    // Only declared when EIGEN_HAS_VARIADIC_TEMPLATES is enabled.
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& setValues(
        const typename internal::Initializer::InitList& vals) {
      TensorEvaluator eval(derived(), DefaultDevice());
      internal::initialize_tensor(eval, vals);
      return derived();
    }
#endif // EIGEN_HAS_VARIADIC_TEMPLATES

    // Compound assignment: each operator evaluates
    // derived() OP other.derived(), assigns the result back into derived(),
    // and returns derived().
    template EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& operator+=(const OtherDerived& other) {
      return derived() = derived() + other.derived();
    }
    template EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& operator-=(const OtherDerived& other) {
      return derived() = derived() - other.derived();
    }
    template EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& operator*=(const OtherDerived& other) {
      return derived() = derived() * other.derived();
    }
    template EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& operator/=(const OtherDerived& other) {
      return derived() = derived() / other.derived();
    }

    // Each operation below comes as a const overload and a non-const
    // overload. Both construct the corresponding Op expression from derived()
    // plus the given arguments.

    // TensorLayoutSwapOp over derived().
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorLayoutSwapOp
    swap_layout() const {
      return TensorLayoutSwapOp(derived());
    }
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorLayoutSwapOp
    swap_layout() {
      return TensorLayoutSwapOp(derived());
    }

    // TensorConcatenationOp of derived() and `other`, along `axis`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorConcatenationOp
    concatenate(const OtherDerived& other, const Axis& axis) const {
      return TensorConcatenationOp(derived(), other, axis);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorConcatenationOp
    concatenate(const OtherDerived& other, const Axis& axis) {
      return TensorConcatenationOp(derived(), other, axis);
    }

    // TensorReshapingOp of derived() with `newDimensions`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorReshapingOp
    reshape(const NewDimensions& newDimensions) const {
      return TensorReshapingOp(derived(), newDimensions);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorReshapingOp
    reshape(const NewDimensions& newDimensions) {
      return TensorReshapingOp(derived(), newDimensions);
    }

    // TensorSlicingOp of derived() with `startIndices` and `sizes`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorSlicingOp
    slice(const StartIndices& startIndices, const Sizes& sizes) const {
      return TensorSlicingOp(derived(), startIndices, sizes);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorSlicingOp
    slice(const StartIndices& startIndices, const Sizes& sizes) {
      return TensorSlicingOp(derived(), startIndices, sizes);
    }

    // TensorStridingSlicingOp of derived() with `startIndices`,
    // `stopIndices` and `strides`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorStridingSlicingOp
    stridedSlice(const StartIndices& startIndices, const StopIndices& stopIndices, const Strides& strides) const {
      return TensorStridingSlicingOp(derived(), startIndices, stopIndices, strides);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorStridingSlicingOp
    stridedSlice(const StartIndices& startIndices, const StopIndices& stopIndices, const Strides& strides) {
      return TensorStridingSlicingOp(derived(), startIndices, stopIndices, strides);
    }

    // TensorChippingOp at `offset` along the dimension DimId.
    // NOTE(review): DimId is presumably a non-type template parameter; its
    // declaration is missing in this copy.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorChippingOp
    chip(const Index offset) const {
      return TensorChippingOp(derived(), offset, DimId);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorChippingOp
    chip(const Index offset) {
      return TensorChippingOp(derived(), offset, DimId);
    }

    // TensorChippingOp at `offset`, with the dimension `dim` passed as a
    // runtime argument.
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorChippingOp
    chip(const Index offset, const Index dim) const {
      return TensorChippingOp(derived(), offset, dim);
    }
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorChippingOp
    chip(const Index offset, const Index dim) {
      return TensorChippingOp(derived(), offset, dim);
    }

    // TensorReverseOp of derived() with `rev`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorReverseOp
    reverse(const ReverseDimensions& rev) const {
      return TensorReverseOp(derived(), rev);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorReverseOp
    reverse(const ReverseDimensions& rev) {
      return TensorReverseOp(derived(), rev);
    }

    // TensorShufflingOp of derived() with `shfl`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorShufflingOp
    shuffle(const Shuffle& shfl) const {
      return TensorShufflingOp(derived(), shfl);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorShufflingOp
    shuffle(const Shuffle& shfl) {
      return TensorShufflingOp(derived(), shfl);
    }

    // TensorStridingOp of derived() with `strides`.
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    const TensorStridingOp
    stride(const Strides& strides) const {
      return TensorStridingOp(derived(), strides);
    }
    template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    TensorStridingOp
    stride(const Strides& strides) {
      return TensorStridingOp(derived(), strides);
    }

    // Select the device on which to evaluate the expression.
    // Returns a TensorDevice holding `dev` and derived(). Unlike the members
    // above, this one is not marked EIGEN_DEVICE_FUNC.
    template
    TensorDevice device(const DeviceType& dev) {
      return TensorDevice(dev, derived());
    }

    // Select the async device on which to evaluate the expression.
    // `done` is taken by value and moved into the returned TensorAsyncDevice.
    template
    TensorAsyncDevice device(const DeviceType& dev, DoneCallback done) {
      return TensorAsyncDevice(dev, derived(), std::move(done));
    }

 protected:
    // Special members come from these Eigen macros. They are presumably
    // defaulted constructor/destructor/copy-constructor declarations; confirm
    // against the macro definitions. They are declared protected.
    EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(TensorBase)
    EIGEN_DEFAULT_COPY_CONSTRUCTOR(TensorBase)

    // Assigns the expression `other` to derived(). It builds a TensorAssignOp
    // and runs it with internal::TensorExecutor on DefaultDevice(), then
    // returns derived(). The in-place setters and compound operators above
    // route through this assignment.
    template EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& operator=(const OtherDerived& other)
    {
      typedef TensorAssignOp Assign;
      Assign assign(derived(), other.derived());
      internal::TensorExecutor::run(assign, DefaultDevice());
      return derived();
    }

    // Return this object as `Derived&` / `const Derived&` via static_cast.
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE Derived& derived() { return *static_cast(this); }
    EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast(this); }
};
#endif // EIGEN_PARSED_BY_DOXYGEN

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H