aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-28 10:02:47 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-28 10:02:47 -0800
commit5a6ea4edf61b5626a781070c6342fc16606b490a (patch)
tree2e94aad11b5ca76e48e17bce25979694441879bc
parent9dfdbd7e568bd3aa9a4610986dcfc679b9ea425d (diff)
Added more tests to cover tensor reductions
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h43
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h62
-rw-r--r--unsupported/test/cxx11_tensor_reduction.cpp37
3 files changed, 128 insertions, 14 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index 7b8d34321..38586d067 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -37,7 +37,11 @@ template <typename T> struct SumReducer
return accum;
}
template <typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
+ return vaccum;
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return saccum + predux(vaccum);
}
};
@@ -45,16 +49,16 @@ template <typename T> struct SumReducer
template <typename T> struct MeanReducer
{
static const bool PacketAccess = true;
- MeanReducer() : count_(0) { }
+ MeanReducer() : scalarCount_(0), packetCount_(0) { }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) {
(*accum) += t;
- count_++;
+ scalarCount_++;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) {
(*accum) = padd<Packet>(*accum, p);
- count_ += packet_traits<Packet>::size;
+ packetCount_++;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
@@ -65,15 +69,20 @@ template <typename T> struct MeanReducer
return pset1<Packet>(0);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
- return accum / count_;
+ return accum / scalarCount_;
}
template <typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
- return (saccum + predux(vaccum)) / count_;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
+ return pdiv(vaccum, pset1<Packet>(packetCount_));
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
+ return (saccum + predux(vaccum)) / (scalarCount_ + packetCount_ * packet_traits<Packet>::size);
}
protected:
- int count_;
+ int scalarCount_;
+ int packetCount_;
};
template <typename T> struct MaxReducer
@@ -99,7 +108,11 @@ template <typename T> struct MaxReducer
return accum;
}
template <typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
+ return vaccum;
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return (std::max)(saccum, predux_max(vaccum));
}
};
@@ -127,7 +140,11 @@ template <typename T> struct MinReducer
return accum;
}
template <typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
+ return vaccum;
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return (std::min)(saccum, predux_min(vaccum));
}
};
@@ -156,7 +173,11 @@ template <typename T> struct ProdReducer
return accum;
}
template <typename Packet>
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizePacket(const T saccum, const Packet& vaccum) const {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
+ return vaccum;
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return saccum * predux_mul(vaccum);
}
};
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
index 209749042..7ff47673d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
@@ -181,7 +181,7 @@ template<typename FirstType, typename... OtherTypes> size_t array_prod(const Ind
result *= sizes[i];
}
return result;
-}
+};
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
@@ -307,6 +307,52 @@ struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > {
};
+template <typename T>
+struct index_statically_gt {
+ constexpr bool operator() (DenseIndex, DenseIndex) const {
+ return false;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_gt<IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] > value;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_gt<const IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] > value;
+ }
+};
+
+template <typename T>
+struct index_statically_lt {
+ constexpr bool operator() (DenseIndex, DenseIndex) const {
+ return false;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_lt<IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] < value;
+ }
+};
+
+template <typename FirstType, typename... OtherTypes>
+struct index_statically_lt<const IndexList<FirstType, OtherTypes...> > {
+ constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
+ return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
+ IndexList<FirstType, OtherTypes...>()[i] < value;
+ }
+};
+
} // end namespace internal
} // end namespace Eigen
@@ -351,6 +397,20 @@ struct index_statically_ne {
}
};
+template <typename T>
+struct index_statically_gt {
+ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
+ return false;
+ }
+};
+
+template <typename T>
+struct index_statically_lt {
+ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
+ return false;
+ }
+};
+
} // end namespace internal
} // end namespace Eigen
diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp
index 99e19eba4..5c3184833 100644
--- a/unsupported/test/cxx11_tensor_reduction.cpp
+++ b/unsupported/test/cxx11_tensor_reduction.cpp
@@ -369,6 +369,37 @@ static void test_innermost_first_dims() {
}
}
+template <int DataLayout>
+static void test_reduce_middle_dims() {
+ Tensor<float, 4, DataLayout> in(72, 53, 97, 113);
+ Tensor<float, 2, DataLayout> out(72, 53);
+ in.setRandom();
+
+// Reduce on the innermost dimensions.
+#if __cplusplus <= 199711L
+ array<int, 2> reduction_axis;
+ reduction_axis[0] = 1;
+ reduction_axis[1] = 2;
+#else
+ // This triggers the use of packets for RowMajor.
+ Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>> reduction_axis;
+#endif
+
+ out = in.maximum(reduction_axis);
+
+ for (int i = 0; i < 72; ++i) {
+ for (int j = 0; j < 113; ++j) {
+ float expected = -1e10f;
+ for (int k = 0; k < 53; ++k) {
+ for (int l = 0; l < 97; ++l) {
+ expected = (std::max)(expected, in(i, k, l, j));
+ }
+ }
+ VERIFY_IS_APPROX(out(i, j), expected);
+ }
+ }
+}
+
void test_cxx11_tensor_reduction() {
CALL_SUBTEST(test_simple_reductions<ColMajor>());
CALL_SUBTEST(test_simple_reductions<RowMajor>());
@@ -380,8 +411,10 @@ void test_cxx11_tensor_reduction() {
CALL_SUBTEST(test_tensor_maps<RowMajor>());
CALL_SUBTEST(test_static_dims<ColMajor>());
CALL_SUBTEST(test_static_dims<RowMajor>());
- CALL_SUBTEST(test_innermost_last_dims<RowMajor>());
CALL_SUBTEST(test_innermost_last_dims<ColMajor>());
- CALL_SUBTEST(test_innermost_first_dims<RowMajor>());
+ CALL_SUBTEST(test_innermost_last_dims<RowMajor>());
CALL_SUBTEST(test_innermost_first_dims<ColMajor>());
+ CALL_SUBTEST(test_innermost_first_dims<RowMajor>());
+ CALL_SUBTEST(test_reduce_middle_dims<ColMajor>());
+ CALL_SUBTEST(test_reduce_middle_dims<RowMajor>());
}