aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h72
-rw-r--r--unsupported/test/cxx11_tensor_reduction.cpp33
2 files changed, 101 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index 92984336c..e9aa22183 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -25,12 +25,12 @@ template <typename T> struct SumReducer
}
private:
- T m_sum;
+ typename internal::remove_all<T>::type m_sum;
};
template <typename T> struct MaxReducer
{
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MaxReducer() : m_max((std::numeric_limits<T>::min)()) { }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MaxReducer() : m_max(-(std::numeric_limits<T>::max)()) { }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t) {
if (t > m_max) { m_max = t; }
}
@@ -39,7 +39,7 @@ template <typename T> struct MaxReducer
}
private:
- T m_max;
+ typename internal::remove_all<T>::type m_max;
};
template <typename T> struct MinReducer
@@ -53,9 +53,73 @@ template <typename T> struct MinReducer
}
private:
- T m_min;
+ typename internal::remove_all<T>::type m_min;
};
+
+#if !defined (EIGEN_USE_GPU) || !defined(__CUDACC__) || !defined(__CUDA_ARCH__)
+// We're not compiling a cuda kernel
+template <typename T> struct UniformRandomGenerator {
+ template<typename Index>
+ T operator()(Index, Index = 0) const {
+ return random<T>();
+ }
+ template<typename Index>
+ typename internal::packet_traits<T>::type packetOp(Index, Index = 0) const {
+ const int packetSize = internal::packet_traits<T>::size;
+ EIGEN_ALIGN_DEFAULT T values[packetSize];
+ for (int i = 0; i < packetSize; ++i) {
+ values[i] = random<T>();
+ }
+ return internal::pload<typename internal::packet_traits<T>::type>(values);
+ }
+};
+
+#else
+
+// We're compiling a cuda kernel
+template <typename T> struct UniformRandomGenerator;
+
+template <> struct UniformRandomGenerator<float> {
+ UniformRandomGenerator() {
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ curand_init(0, tid, 0, &m_state);
+ }
+
+ template<typename Index>
+ float operator()(Index, Index = 0) const {
+ return curand_uniform(&m_state);
+ }
+ template<typename Index>
+ float4 packetOp(Index, Index = 0) const {
+ return curand_uniform4(&m_state);
+ }
+
+ private:
+ mutable curandStatePhilox4_32_10_t m_state;
+};
+
+template <> struct UniformRandomGenerator<double> {
+ UniformRandomGenerator() {
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ curand_init(0, tid, 0, &m_state);
+ }
+ template<typename Index>
+ double operator()(Index, Index = 0) const {
+ return curand_uniform_double(&m_state);
+ }
+ template<typename Index>
+ double2 packetOp(Index, Index = 0) const {
+ return curand_uniform2_double(&m_state);
+ }
+
+ private:
+ mutable curandStatePhilox4_32_10_t m_state;
+};
+
+#endif
+
+
} // end namespace internal
} // end namespace Eigen
diff --git a/unsupported/test/cxx11_tensor_reduction.cpp b/unsupported/test/cxx11_tensor_reduction.cpp
index 27135b982..da9885166 100644
--- a/unsupported/test/cxx11_tensor_reduction.cpp
+++ b/unsupported/test/cxx11_tensor_reduction.cpp
@@ -139,9 +139,42 @@ static void test_user_defined_reductions()
}
+static void test_tensor_maps()
+{
+ int inputs[2*3*5*7];
+ TensorMap<Tensor<int, 4> > tensor_map(inputs, 2,3,5,7);
+ TensorMap<Tensor<const int, 4> > tensor_map_const(inputs, 2,3,5,7);
+ const TensorMap<Tensor<const int, 4> > tensor_map_const_const(inputs, 2,3,5,7);
+
+ tensor_map.setRandom();
+ array<ptrdiff_t, 2> reduction_axis;
+ reduction_axis[0] = 1;
+ reduction_axis[1] = 3;
+
+ Tensor<int, 2> result = tensor_map.sum(reduction_axis);
+ Tensor<int, 2> result2 = tensor_map_const.sum(reduction_axis);
+ Tensor<int, 2> result3 = tensor_map_const_const.sum(reduction_axis);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ int sum = 0;
+ for (int k = 0; k < 3; ++k) {
+ for (int l = 0; l < 7; ++l) {
+ sum += tensor_map(i, k, j, l);
+ }
+ }
+ VERIFY_IS_EQUAL(result(i, j), sum);
+ VERIFY_IS_EQUAL(result2(i, j), sum);
+ VERIFY_IS_EQUAL(result3(i, j), sum);
+ }
+ }
+}
+
+
void test_cxx11_tensor_reduction()
{
CALL_SUBTEST(test_simple_reductions());
CALL_SUBTEST(test_full_reductions());
CALL_SUBTEST(test_user_defined_reductions());
+ CALL_SUBTEST(test_tensor_maps());
}