aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2017-04-14 13:23:35 -0700
committerGravatar RJ Ryan <rjryan@google.com>2017-04-14 13:23:35 -0700
commit949a2da38cbfebe358a25dc59b47abb67beb4126 (patch)
treefc5e2b7b0bf11bd52f29c30d18d95d391a2a8fe3 /unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
parentd9084ac8e142697f0d767092a17ffc3a7a18a2e4 (diff)
Use scalar_sum_op and scalar_quotient_op instead of operator+ and operator/ in MeanReducer.
Improves support for std::complex types when compiling for CUDA. Expands on e2e9cdd16970914cf0a892fea5e7c4402b3ede41 and 2bda1b0d93fb627d0c500ec48b20302d44c32cb7 .
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h8
1 files changed, 6 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index 3b4f8eda1..5dcc3794c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -166,7 +166,8 @@ template <typename T> struct MeanReducer
return pset1<Packet>(initialize());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
- return accum / scalarCount_;
+ internal::scalar_quotient_op<T> quotient_op;
+ return quotient_op(accum, T(scalarCount_));
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
@@ -175,7 +176,10 @@ template <typename T> struct MeanReducer
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
internal::scalar_sum_op<T> sum_op;
- return sum_op(saccum, predux(vaccum)) / (scalarCount_ + packetCount_ * unpacket_traits<Packet>::size);
+ internal::scalar_quotient_op<T> quotient_op;
+ return quotient_op(
+ sum_op(saccum, predux(vaccum)),
+ T(scalarCount_ + packetCount_ * unpacket_traits<Packet>::size));
}
protected: