diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-02-19 10:05:59 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-02-19 10:05:59 -0800 |
commit | 180156ba1aefceae0bd93f056e5807a83ccbb1b5 (patch) | |
tree | de29c755cead6e1d6b1c55c58f78c003bd6799f6 /unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | |
parent | 5c4901b83a3ec15988521e195abc05e804c541dc (diff) |
Added support for tensor reductions on half floats
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index f94ffa020..e2d876140 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -72,11 +72,12 @@ template <typename T> struct SumReducer } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { - return static_cast<T>(0); + internal::scalar_cast_op<int, T> conv; + return conv(0); } template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { - return pset1<Packet>(0); + return pset1<Packet>(initialize()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { return accum; @@ -110,11 +111,12 @@ template <typename T> struct MeanReducer } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { - return static_cast<T>(0); + internal::scalar_cast_op<int, T> conv; + return conv(0); } template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { - return pset1<Packet>(0); + return pset1<Packet>(initialize()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { return accum / scalarCount_; @@ -214,11 +216,12 @@ template <typename T> struct ProdReducer } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { - return static_cast<T>(1); + internal::scalar_cast_op<int, T> conv; + return conv(1); } template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { - return pset1<Packet>(1); + return pset1<Packet>(initialize()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { return accum; |