// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2014 Benoit Steiner // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H #define EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H namespace Eigen { /** \internal * * \class TensorDimensions * \ingroup CXX11_Tensor_Module * * \brief Set of classes used to encode and store the dimensions of a Tensor. * * The Sizes class encodes as part of the type the number of dimensions and the * sizes corresponding to each dimension. It uses no storage space since it is * entirely known at compile time. * The DSizes class is its dynamic sibling: the number of dimensions is known * at compile time but the sizes are set during execution. * * \sa Tensor */ // Boilerplate code namespace internal { template struct dget { static const std::ptrdiff_t value = get::value; }; template struct fixed_size_tensor_index_linearization_helper { template EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(array const& indices, const Dimensions& dimensions) { return array_get(indices) + dget::value * fixed_size_tensor_index_linearization_helper::run(indices, dimensions); } }; template struct fixed_size_tensor_index_linearization_helper { template EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(array const&, const Dimensions&) { return 0; } }; template struct fixed_size_tensor_index_extraction_helper { template EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(const Index index, const Dimensions& dimensions) { const Index mult = (index == n-1) ? 1 : 0; return array_get(dimensions) * mult + fixed_size_tensor_index_extraction_helper::run(index, dimensions); } }; template struct fixed_size_tensor_index_extraction_helper { template EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(const Index, const Dimensions&) { return 0; } }; } // end namespace internal // Fixed size #ifndef EIGEN_EMULATE_CXX11_META_H template struct Sizes { typedef internal::numeric_list Base; const Base t = Base(); static const std::ptrdiff_t total_size = internal::arg_prod(Indices...); static const ptrdiff_t count = Base::count; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const { return Base::count; } static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t TotalSize() { return internal::arg_prod(Indices...); } EIGEN_DEVICE_FUNC Sizes() { } template explicit EIGEN_DEVICE_FUNC Sizes(const array& /*indices*/) { // todo: add assertion } #if EIGEN_HAS_VARIADIC_TEMPLATES template EIGEN_DEVICE_FUNC Sizes(DenseIndex...) { } explicit EIGEN_DEVICE_FUNC Sizes(std::initializer_list /*l*/) { // todo: add assertion } #endif template Sizes& operator = (const T& /*other*/) { // add assertion failure if the size of other is different return *this; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::ptrdiff_t index) const { return internal::fixed_size_tensor_index_extraction_helper::run(index, t); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t IndexOfColMajor(const array& indices) const { return internal::fixed_size_tensor_index_linearization_helper::run(indices, t); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t IndexOfRowMajor(const array& indices) const { return internal::fixed_size_tensor_index_linearization_helper::run(indices, t); } }; namespace internal { template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes&) { return Sizes::total_size; } } #else template struct non_zero_size { typedef internal::type2val type; }; template <> struct non_zero_size<0> { typedef internal::null_type type; }; template struct Sizes { typedef typename internal::make_type_list::type, typename non_zero_size::type, typename non_zero_size::type, typename non_zero_size::type, typename non_zero_size::type >::type Base; static const std::ptrdiff_t count = Base::count; static const std::ptrdiff_t total_size = internal::arg_prod::value; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t rank() const { return count; } static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t TotalSize() { return internal::arg_prod::value; } Sizes() { } template explicit Sizes(const array& /*indices*/) { // todo: add assertion } template Sizes& operator = (const T& /*other*/) { // add assertion failure if the size of other is different return *this; } #if EIGEN_HAS_VARIADIC_TEMPLATES template Sizes(DenseIndex... /*indices*/) { } explicit Sizes(std::initializer_list) { // todo: add assertion } #else EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex) { } EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex) { } EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex) { } EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) { } EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) { } #endif EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index operator[] (const Index index) const { switch (index) { case 0: return internal::get<0, Base>::value; case 1: return internal::get<1, Base>::value; case 2: return internal::get<2, Base>::value; case 3: return internal::get<3, Base>::value; case 4: return internal::get<4, Base>::value; default: eigen_assert(false && "index overflow"); return static_cast(-1); } } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t IndexOfColMajor(const array& indices) const { return internal::fixed_size_tensor_index_linearization_helper::run(indices, *reinterpret_cast(this)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t IndexOfRowMajor(const array& indices) const { return internal::fixed_size_tensor_index_linearization_helper::run(indices, *reinterpret_cast(this)); } }; namespace internal { template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes&) { return Sizes::total_size; } } #endif // Boilerplate namespace internal { template struct tensor_index_linearization_helper { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array const& indices, array const& dimensions) { return array_get(indices) + array_get(dimensions) * tensor_index_linearization_helper::run(indices, dimensions); } }; template struct tensor_index_linearization_helper { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array const& indices, array const&) { return array_get(indices); } }; } // end namespace internal // Dynamic size template struct DSizes : array { typedef array Base; static const int count = NumDims; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return NumDims; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const { return (NumDims == 0) ? 1 : internal::array_prod(*static_cast(this)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DSizes() { for (int i = 0 ; i < NumDims; ++i) { (*this)[i] = 0; } } EIGEN_DEVICE_FUNC explicit DSizes(const array& a) : Base(a) { } EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) { eigen_assert(NumDims == 1); (*this)[0] = i0; } EIGEN_DEVICE_FUNC DSizes(const DimensionList& a) { for (int i = 0 ; i < NumDims; ++i) { (*this)[i] = a[i]; } } // Enable DSizes index type promotion only if we are promoting to the // larger type, e.g. allow to promote dimensions of type int to long. template EIGEN_DEVICE_FUNC explicit DSizes(const array& other, // Default template parameters require c++11. typename internal::enable_if< internal::is_same< DenseIndex, typename internal::promote_index_type< DenseIndex, OtherIndex >::type >::value, void*>::type = 0) { for (int i = 0; i < NumDims; ++i) { (*this)[i] = static_cast(other[i]); } } #ifdef EIGEN_HAS_INDEX_LIST template EIGEN_DEVICE_FUNC explicit DSizes(const Eigen::IndexList& dimensions) { for (int i = 0; i < dimensions.count; ++i) { (*this)[i] = dimensions[i]; } } #endif #ifndef EIGEN_EMULATE_CXX11_META_H template EIGEN_DEVICE_FUNC DSizes(const Sizes& a) { for (int i = 0 ; i < NumDims; ++i) { (*this)[i] = a[i]; } } #else template EIGEN_DEVICE_FUNC DSizes(const Sizes& a) { for (int i = 0 ; i < NumDims; ++i) { (*this)[i] = a[i]; } } #endif #if EIGEN_HAS_VARIADIC_TEMPLATES template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, DenseIndex secondDimension, IndexTypes... otherDimensions) : Base({{firstDimension, secondDimension, otherDimensions...}}) { EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 2 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE) } #else EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1) { eigen_assert(NumDims == 2); (*this)[0] = i0; (*this)[1] = i1; } EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) { eigen_assert(NumDims == 3); (*this)[0] = i0; (*this)[1] = i1; (*this)[2] = i2; } EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) { eigen_assert(NumDims == 4); (*this)[0] = i0; (*this)[1] = i1; (*this)[2] = i2; (*this)[3] = i3; } EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) { eigen_assert(NumDims == 5); (*this)[0] = i0; (*this)[1] = i1; (*this)[2] = i2; (*this)[3] = i3; (*this)[4] = i4; } #endif EIGEN_DEVICE_FUNC DSizes& operator = (const array& other) { *static_cast(this) = other; return *this; } // A constexpr would be so much better here EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfColMajor(const array& indices) const { return internal::tensor_index_linearization_helper::run(indices, *static_cast(this)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfRowMajor(const array& indices) const { return internal::tensor_index_linearization_helper::run(indices, *static_cast(this)); } }; template std::ostream& operator<<(std::ostream& os, const DSizes& dims) { os << "["; for (int i = 0; i < NumDims; ++i) { if (i > 0) os << ", "; os << dims[i]; } os << "]"; return os; } // Boilerplate namespace internal { template struct tensor_vsize_index_linearization_helper { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array const& indices, std::vector const& dimensions) { return array_get(indices) + array_get(dimensions) * tensor_vsize_index_linearization_helper::run(indices, dimensions); } }; template struct tensor_vsize_index_linearization_helper { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array const& indices, std::vector const&) { return array_get(indices); } }; } // end namespace internal namespace internal { template struct array_size > { static const ptrdiff_t value = NumDims; }; template struct array_size > { static const ptrdiff_t value = NumDims; }; #ifndef EIGEN_EMULATE_CXX11_META_H template struct array_size > { static const std::ptrdiff_t value = Sizes::count; }; template struct array_size > { static const std::ptrdiff_t value = Sizes::count; }; template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes&) { return get >::value; } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<>&) { eigen_assert(false && "should never be called"); return -1; } #else template struct array_size > { static const ptrdiff_t value = Sizes::count; }; template struct array_size > { static const ptrdiff_t value = Sizes::count; }; template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes&) { return get::Base>::value; } #endif template struct sizes_match_below_dim { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) { return false; } }; template struct sizes_match_below_dim { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1& dims1, Dims2& dims2) { return (array_get(dims1) == array_get(dims2)) && sizes_match_below_dim::run(dims1, dims2); } }; template struct sizes_match_below_dim { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) { return true; } }; } // end namespace internal template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool dimensions_match(Dims1 dims1, Dims2 dims2) { return internal::sizes_match_below_dim::value, internal::array_size::value>::run(dims1, dims2); } } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H