diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-09-14 15:25:27 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2018-09-14 15:25:27 -0700 |
commit | 1b8d70a22b83d63667bbefe3899d9a2e0c2c8b78 (patch) | |
tree | e50af92d4d253a94d1e9cc87aa748e5c9a579014 /unsupported/Eigen/CXX11/src/Tensor | |
parent | 9b864cdb3789dbddaa26e53dd85393713b24ce94 (diff) |
Support reshaping with static shapes and dimensions conversion in tensor broadcasting
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 2 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h | 10 |
2 files changed, 11 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index e5cf93ab0..c102a43fb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -641,7 +641,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> return; } - const Dimensions& input_dims = m_impl.dimensions(); + const Dimensions& input_dims = Dimensions(m_impl.dimensions()); // Pre-fill input_block_sizes, broadcast_block_sizes, // broadcast_block_strides, and broadcast_tensor_strides. Later on we will diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h index 7c26b1682..fe0d57f31 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h @@ -290,6 +290,16 @@ struct DSizes : array<DenseIndex, NumDims> { } } +#ifdef EIGEN_HAS_INDEX_LIST + EIGEN_DEVICE_FUNC + template <typename FirstType, typename... OtherTypes> + DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) { + for (int i = 0; i < dimensions.count; ++i) { + (*this)[i] = dimensions[i]; + } + } +#endif + #ifndef EIGEN_EMULATE_CXX11_META_H template <typename std::ptrdiff_t... Indices> EIGEN_DEVICE_FUNC DSizes(const Sizes<Indices...>& a) { |