aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-14 15:25:27 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-09-14 15:25:27 -0700
commit1b8d70a22b83d63667bbefe3899d9a2e0c2c8b78 (patch)
treee50af92d4d253a94d1e9cc87aa748e5c9a579014
parent9b864cdb3789dbddaa26e53dd85393713b24ce94 (diff)
Support reshaping with static shapes and dimensions conversion in tensor broadcasting
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h2
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h10
-rw-r--r--unsupported/test/cxx11_tensor_broadcasting.cpp2
-rw-r--r--unsupported/test/cxx11_tensor_morphing.cpp23
4 files changed, 35 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
index e5cf93ab0..c102a43fb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
@@ -641,7 +641,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
return;
}
- const Dimensions& input_dims = m_impl.dimensions();
+ const Dimensions& input_dims = Dimensions(m_impl.dimensions());
// Pre-fill input_block_sizes, broadcast_block_sizes,
// broadcast_block_strides, and broadcast_tensor_strides. Later on we will
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
index 7c26b1682..fe0d57f31 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
@@ -290,6 +290,16 @@ struct DSizes : array<DenseIndex, NumDims> {
}
}
+#ifdef EIGEN_HAS_INDEX_LIST
+ EIGEN_DEVICE_FUNC
+ template <typename FirstType, typename... OtherTypes>
+ DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) {
+ for (int i = 0; i < dimensions.count; ++i) {
+ (*this)[i] = dimensions[i];
+ }
+ }
+#endif
+
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices>
EIGEN_DEVICE_FUNC DSizes(const Sizes<Indices...>& a) {
diff --git a/unsupported/test/cxx11_tensor_broadcasting.cpp b/unsupported/test/cxx11_tensor_broadcasting.cpp
index 2f8ab6afd..7df5b53d6 100644
--- a/unsupported/test/cxx11_tensor_broadcasting.cpp
+++ b/unsupported/test/cxx11_tensor_broadcasting.cpp
@@ -115,7 +115,7 @@ static void test_static_broadcasting()
Tensor<float, 3, DataLayout> tensor(8,3,5);
tensor.setRandom();
-#if EIGEN_HAS_CONSTEXPR
+#if defined(EIGEN_HAS_INDEX_LIST)
Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
#else
Eigen::array<int, 3> broadcasts;
diff --git a/unsupported/test/cxx11_tensor_morphing.cpp b/unsupported/test/cxx11_tensor_morphing.cpp
index 6365cd89a..4cbe15b63 100644
--- a/unsupported/test/cxx11_tensor_morphing.cpp
+++ b/unsupported/test/cxx11_tensor_morphing.cpp
@@ -41,6 +41,28 @@ static void test_simple_reshape()
}
}
+template <typename>
+static void test_static_reshape() {
+#if defined(EIGEN_HAS_INDEX_LIST)
+ using Eigen::type2index;
+
+ Tensor<float, 5> tensor(2, 3, 1, 7, 1);
+ tensor.setRandom();
+
+ // New dimensions: [2, 3, 7]
+ Eigen::IndexList<type2index<2>, type2index<3>, type2index<7>> dim;
+ Tensor<float, 3> reshaped = tensor.reshape(dim);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ VERIFY_IS_EQUAL(tensor(i, j, 0, k, 0), reshaped(i, j, k));
+ }
+ }
+ }
+#endif
+}
+
template<typename>
static void test_reshape_in_expr() {
MatrixXf m1(2,3*5*7*11);
@@ -462,6 +484,7 @@ static void test_composition()
EIGEN_DECLARE_TEST(cxx11_tensor_morphing)
{
CALL_SUBTEST_1(test_simple_reshape<void>());
+ CALL_SUBTEST_1(test_static_reshape<void>());
CALL_SUBTEST_1(test_reshape_in_expr<void>());
CALL_SUBTEST_1(test_reshape_as_lvalue<void>());