diff options
field | value | date |
---|---|---|
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-06-28 15:50:39 -0700 |
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-06-28 15:50:39 -0700 |
commit | 81a03bec75aac90aa343fccf6a7daf735e28c20d (patch) | |
tree | 487d4d27c6cd796b888246400c91540d87ba58dc /unsupported/Eigen/CXX11/src | |
parent | 8053eeb51e1735f3956f49555ac3901388c2ccca (diff) | |
Fix TensorReverse on GPU with m_stride[i]==0
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h | 10 |
1 files changed, 5 insertions, 5 deletions
```diff
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
index 6faff87bf..42205db31 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h
@@ -122,6 +122,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
     RawAccess = false
   };
 
+  typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
   typedef typename internal::remove_const<Scalar>::type ScalarNoConst;
   typedef internal::TensorBlock<ScalarNoConst, Index, NumDims, Layout>
       OutputTensorBlock;
@@ -141,17 +143,15 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
       m_strides[0] = 1;
       for (int i = 1; i < NumDims; ++i) {
         m_strides[i] = m_strides[i-1] * m_dimensions[i-1];
+        if (m_strides[i] > 0) m_fastStrides[i] = IndexDivisor(m_strides[i]);
       }
     } else {
       m_strides[NumDims-1] = 1;
       for (int i = NumDims - 2; i >= 0; --i) {
         m_strides[i] = m_strides[i+1] * m_dimensions[i+1];
+        if (m_strides[i] > 0) m_fastStrides[i] = IndexDivisor(m_strides[i]);
       }
     }
-    // Remember the strides for fast division.
-    for (int i = 0; i < NumDims; ++i) {
-      m_fastStrides[i] = internal::TensorIntDivisor<Index>(m_strides[i]);
-    }
   }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -377,7 +377,7 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
  protected:
   Dimensions m_dimensions;
   array<Index, NumDims> m_strides;
-  array<internal::TensorIntDivisor<Index>, NumDims> m_fastStrides;
+  array<IndexDivisor, NumDims> m_fastStrides;
   TensorEvaluator<ArgType, Device> m_impl;
   ReverseDimensions m_reverse;
   const Device& m_device;
```