diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-07-10 11:29:51 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-07-10 11:29:51 -0700 |
commit | 40bb98e76acbe6e077903e15896c100ee6cced39 (patch) | |
tree | 1afb999f8c46c7a76441df21ab93d5ea42cf7d06 /unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h | |
parent | 9b7a6f0122f6817a3c12bc75803d4270cd9db507 (diff) |
Added primitives to compare tensor dimensions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h index 3e5687915..3b169a06f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h @@ -210,6 +210,60 @@ struct DSizes : array<DenseIndex, NumDims> { }; +namespace internal { + +template <typename DenseIndex, std::size_t NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > { + static const size_t value = NumDims; +}; +template <typename DenseIndex, std::size_t NumDims> struct array_size<DSizes<DenseIndex, NumDims> > { + static const size_t value = NumDims; +}; +#ifndef EIGEN_EMULATE_CXX11_META_H +template <typename std::size_t... Indices> struct array_size<const Sizes<Indices...> > { +static const size_t value = Sizes<Indices...>::count; +}; +template <typename std::size_t... Indices> struct array_size<Sizes<Indices...> > { +static const size_t value = Sizes<Indices...>::count; +}; +#else +template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > { + static const size_t value = Sizes<V1,V2,V3,V4,V5>::count; +}; +template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<Sizes<V1,V2,V3,V4,V5> > { + static const size_t value = Sizes<V1,V2,V3,V4,V5>::count; +}; +template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes<V1,V2,V3,V4,V5>& a) { + return get<n, typename Sizes<V1,V2,V3,V4,V5>::Base>::value; +}; + +#endif + + +template <typename Dims1, typename Dims2, size_t n> +struct sizes_match_up_to_dim { + static inline bool run(Dims1& dims1, Dims2& dims2) { + return (array_get<n>(dims1) == array_get<n>(dims2)) & + sizes_match_up_to_dim<Dims1, Dims2, n-1>::run(dims1, dims2); + } +}; +template <typename Dims1, typename Dims2> +struct sizes_match_up_to_dim<Dims1, Dims2, 0> { + static inline bool run(Dims1& dims1, Dims2& dims2) { + return (array_get<0>(dims1) == array_get<0>(dims2)); + } +}; + +template <typename Dims1, typename Dims2> +bool dimensions_match(Dims1& dims1, Dims2& dims2) { + if (array_size<Dims1>::value != array_size<Dims2>::value) { + return false; + } + return sizes_match_up_to_dim<Dims1, Dims2, array_size<Dims1>::value-1>::run(dims1, dims2); +} + +} // end namespace internal + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H |